
Commit 2f66765

[mlir][tensor] Relax the logic to generalise tensor.pack
Make sure that the logic to generalize tensor.pack (into e.g. tensor.pad and linalg.transpose) does indeed allow multiple dynamic tile sizes. This was effectively already implemented in llvm#109815 - in this PR I am merely removing one `if` condition and adding a test. I also took the liberty of renaming a few test functions to better highlight the differences between the old and the new tests. Follow-on for llvm#109815.
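To illustrate the relaxed behaviour, this is the kind of op the pattern now accepts - a sketch distilled from the test added in this commit, with illustrative value names (%input, %output, %pad, %tile_0, %tile_1) - namely a tensor.pack whose inner tile sizes are both dynamic:

%0 = tensor.pack %input padding_value(%pad : f32)
       inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1]
       into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>

Previously the pattern bailed out as soon as more than one inner tile size was dynamic; the IR it now produces for this op is sketched after the Transforms.cpp diff below.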
1 parent 66f84c8 commit 2f66765

File tree: 2 files changed, +35 -15 lines


mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 0 additions & 6 deletions
@@ -1145,12 +1145,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::count_if(packOp.getMixedTiles(),
-                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
-    return rewriter.notifyMatchFailure(
-        packOp, "at most one dynamic tile size is supported");
-  }
-
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
   if (llvm::any_of(packOp.getTiledOuterDims(),
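With the guard above removed, a tensor.pack carrying more than one dynamic inner tile size (such as @simple_pad_and_pack_dynamic_tiles added in the test below) now matches the pattern. Reconstructed from the new CHECK lines, the IR the pattern produces looks roughly as follows. This is a sketch in plain MLIR with illustrative value names (%input, %output, %pad, %tile_0, %tile_1), not the verbatim pass output; no linalg.transpose is created because the inner-dim permutation is the identity:

%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
// High padding amounts: inner tile size minus the source size (the source is tensor<5x1xf32>).
%pad_high_1 = affine.apply affine_map<()[s0] -> (s0 - 5)>()[%tile_0]
%pad_high_2 = affine.apply affine_map<()[s0] -> (s0 - 1)>()[%tile_1]
%padded = tensor.pad %input low[0, 0] high[%pad_high_1, %pad_high_2] {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<5x1xf32> to tensor<?x?xf32>
// Extract the single (dynamically sized) tile ...
%tile = tensor.extract_slice %padded[0, 0] [%tile_0, %tile_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// ... and insert it into the destination, whose tiled outer dims are all unit.
%dim_1 = tensor.dim %output, %c2 : tensor<1x1x?x?xf32>
%dim_2 = tensor.dim %output, %c3 : tensor<1x1x?x?xf32>
%res = tensor.insert_slice %tile into %output[0, 0, 0, 0] [1, 1, %dim_1, %dim_2] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>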

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 35 additions & 9 deletions
@@ -19,13 +19,14 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
 
 // -----
 
-func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
+func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<()[s0] -> (s0 - 1)>
 
-// CHECK-LABEL: func.func @simple_pad_and_pack
+// CHECK-LABEL: func.func @simple_pad_and_pack_static_tiles
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]

@@ -36,18 +37,18 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
 
-/// Same as example above, but with dynamic tile size.
+/// Same as example above, but with 1 dynamic tile size.
 
-func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
   return %0 : tensor<1x1x?x2xf32>
 }
 
-// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic(
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile(
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK-SAME: %[[HIGH_VAL:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
 // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {

@@ -58,21 +59,21 @@ func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<
 // CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
 
-/// Same as example above, but with scalable tile size.
+/// Same as example above, but with 1 scalable tile size.
 
 /// NOTE: For this example to make sense in practice, the "?" in the output shape
 /// should effectively be 8 * vector.vscale (and that's what tensor.dim
 /// below should return).
 
-func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
   %c8 = arith.constant 8 : index
   %vscale = vector.vscale
   %c8_vscale = arith.muli %vscale, %c8 : index
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
   return %0 : tensor<1x1x?x2xf32>
 }
 
-// CHECK-LABEL: func.func @simple_pad_and_pack_scalable(
+// CHECK-LABEL: func.func @simple_pad_and_pack_scalable_tile(
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {

@@ -89,6 +90,31 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
 // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
 
+/// Same as example above, but with both tile sizes dynamic.
+
+func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %high_1: index, %high_2: index) -> tensor<1x1x?x?xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high_1, %high_2] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
+  return %0 : tensor<1x1x?x?xf32>
+}
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tiles(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32>,
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32,
+// CHECK-SAME: %[[HIGH_VAL_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[HIGH_VAL_2:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x?xf32> {
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[PAD_HIGH_1:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL_1]]]
+// CHECK: %[[PAD_HIGH_2:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[HIGH_VAL_2]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[HIGH_VAL_1]], %[[HIGH_VAL_2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x?xf32>
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[DEST]], %[[C3]] : tensor<1x1x?x?xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM_1]], %[[DIM_2]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x?xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
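For reference, this test file is driven through mlir-opt's Linalg test pass; a sketch of the invocation is shown below. The exact pass option name is an assumption based on the file's existing RUN-line conventions, so double-check the RUN line at the top of the file:

// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s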
