Skip to content

Commit 8d08166

Browse files
authored
[MLIR][Tensor] Fix source/dest type check in UnPackOp canonicalize (#106094)
Fix `RankedTensorType` equality check in unpack op canonicalization.
1 parent b55186e commit 8d08166

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4203,7 +4203,7 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
42034203
}
42044204

42054205
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4206-
// Fold an unpack(pack(x)) to x.
4206+
// Fold an pack(unpack(x)) to x.
42074207
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
42084208
if (unPackOp.getSourceType() != packOp.getDestType())
42094209
return failure();
@@ -4437,9 +4437,9 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
44374437

44384438
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
44394439
PatternRewriter &rewriter) {
4440-
/// pack(unpack(x)) -> x
4440+
/// unpack(pack(x)) -> x
44414441
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4442-
if (packOp.getDestType() != unPackOp.getSourceType())
4442+
if (packOp.getSourceType() != unPackOp.getDestType())
44434443
return failure();
44444444
if (packOp.getPaddingValue() ||
44454445
!hasSameInnerOuterAttribute(packOp, unPackOp) ||

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,6 +2268,19 @@ func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) ->
22682268

22692269
// -----
22702270

2271+
// CHECK: func.func @unpack_pack_with_padding_no_canonicalization(
2272+
// CHECK: tensor.pack
2273+
// CHECK: tensor.unpack
2274+
func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
2275+
%tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
2276+
%tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
2277+
%packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
2278+
%unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
2279+
return %unpacked : tensor<224x512xbf16>
2280+
}
2281+
2282+
// -----
2283+
22712284
// Chain NCnc -> NC -> NC -> NCnc
22722285
// CHECK: func.func @pack_unpack(
22732286
// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,

0 commit comments

Comments
 (0)