Skip to content

Commit 3eb7cce

Browse files
authored
[mlir][linalg] Add tests for tensor.unpack decomposition (llvm#118786)
This commit adds additional tests and documentation for `DecomposeOuterUnitDimsUnPackOpPattern` to ensure symmetry with its counterpart for `tensor.pack`, `DecomposeOuterUnitDimsPackOpPattern`. The new tests aim to improve implementation, documentation, and test coverage for tensor.unpack. They cover the following scenarios: * Static tile sizes: A simple `tensor.unpack` case (`@simple_unpack_static_tiles`). * Dynamic tile size: `tensor.unpack` with a single dynamic tile size (`@simple_unpack_dynamic_tile`). * Transpose: `tensor.unpack` with dynamic tile size and transpose (`@simple_unpack_dynamic_tile_transpose`), currently commented out due to some missing logic (see below) * Scalable tile size: `tensor.unpack` with a scalable inner tile size (@simple_unpack_scalable_tile). Notes: The test `@simple_unpack_dynamic_tile_transpose` is commented out because the logic for capturing dynamic sizes for `tensor::EmptyOp` when some tile sizes are dynamic is incomplete. This missing functionality will be addressed in a follow-up patch.
1 parent 6f5bffd commit 3eb7cce

File tree

3 files changed

+82
-6
lines changed

3 files changed

+82
-6
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,7 +1519,7 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
15191519
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
15201520
/// tensor::InsertSliceOp ops.
15211521
///
1522-
/// Required that all the outer dims of the input tensor::PackOp are 1.
1522+
/// Requires that all the outer dims of the input tensor::PackOp are 1.
15231523
///
15241524
/// Before:
15251525
/// ```
@@ -1555,9 +1555,33 @@ struct DecomposeOuterUnitDimsPackOpPattern
15551555
PatternRewriter &rewriter) const override;
15561556
};
15571557

1558-
/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
1559-
/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
1560-
/// being all 1s.
1558+
/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced
1559+
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
1560+
///
1561+
/// Requires that all the outer dims of the input tensor::UnPackOp are 1.
1562+
///
1563+
/// Before:
1564+
/// ```
1565+
/// %packed = tensor.unpack %input
1566+
/// inner_dims_pos = [1, 0]
1567+
/// inner_tiles = [2, 8]
1568+
/// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
1569+
/// ```
1570+
///
1571+
/// After:
1572+
/// ```
1573+
/// // Rank-reduced extract to obtain the tile
1574+
/// %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
1575+
/// : tensor<1x1x2x8xf32> to tensor<2x8xf32>
1576+
/// // EmptyOp + TransposeOp
1577+
/// %init = tensor.empty() : tensor<8x2xf32>
1578+
/// %transposed = linalg.transpose
1579+
///    ins(%slice : tensor<2x8xf32>)
1580+
///    outs(%init : tensor<8x2xf32>) permutation = [1, 0]
1581+
/// // Extract a slice matching the specified output size
1582+
/// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
1583+
/// : tensor<8x2xf32> to tensor<5x1xf32>
1584+
/// ```
15611585
struct DecomposeOuterUnitDimsUnPackOpPattern
15621586
: public OpRewritePattern<tensor::UnPackOp> {
15631587
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;

mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
6767
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
6868
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
6969

70+
/// Same as example above, but the dynamic tile size is a compile-time constant
71+
/// that's folded away.
72+
7073
func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
7174
%tile_dim_0 = arith.constant 8 : index
7275
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<
1919

2020
// -----
2121

22-
func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
22+
func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
2323
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
2424
return %0 : tensor<5x1xf32>
2525
}
26-
// CHECK-LABEL: func.func @simple_unpack_and_extract_slice
26+
// CHECK-LABEL: func.func @simple_unpack_static_tiles
2727
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
2828
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
2929
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
@@ -33,6 +33,55 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
3333
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
3434
// CHECK: return %[[SLICE]]
3535

36+
/// Same as example above, but with 1 dynamic tile size.
37+
38+
func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
39+
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
40+
return %0 : tensor<5x1xf32>
41+
}
42+
// CHECK-LABEL: func.func @simple_unpack_dynamic_tile
43+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
44+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
45+
// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]
46+
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
47+
// CHECK-NOT: linalg.transpose
48+
// They have the same type, so the insert_slice op is folded
49+
// away.
50+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
51+
// CHECK: return %[[SLICE]]
52+
53+
/// Same as example above, but with 1 dynamic tile size and a transpose
54+
55+
/// FIXME: This is currently broken:
56+
/// * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
57+
58+
//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
59+
// %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
60+
// return %0 : tensor<5x1xf32>
61+
//}
62+
63+
/// Same as example above, but with 1 scalable tile size.
64+
65+
func.func @simple_unpack_scalable_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
66+
%c8 = arith.constant 8 : index
67+
%vscale = vector.vscale
68+
%c8_vscale = arith.muli %vscale, %c8 : index
69+
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
70+
return %0 : tensor<5x1xf32>
71+
}
72+
// CHECK-LABEL: func.func @simple_unpack_scalable_tile
73+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
74+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
75+
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
76+
// CHECK-DAG: %[[VS:.+]] = vector.vscale
77+
// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
78+
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1]
79+
// CHECK-NOT: linalg.transpose
80+
// They have the same type, so the insert_slice op is folded
81+
// away.
82+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
83+
// CHECK: return %[[SLICE]]
84+
3685
// -----
3786

3887
func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32>) -> tensor<32x8xf32>{

0 commit comments

Comments
 (0)