Skip to content

Commit 78348b6

Browse files
authored
[mlir][tensor] Improve tensor.pack simplication pattern. (llvm#76606)
A tensor.pack op can be rewritten to a tensor.expand_shape op if the packing only happens on inner most dimension. This also formats the lit checks better.
1 parent cab156c commit 78348b6

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,20 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
3535

3636
LogicalResult matchAndRewrite(PackOp packOp,
3737
PatternRewriter &rewriter) const override {
38+
if (packOp.getPaddingValue())
39+
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
40+
41+
if (!packOp.getOuterDimsPerm().empty())
42+
return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
43+
3844
RankedTensorType sourceType = packOp.getSourceType();
3945
RankedTensorType destType = packOp.getDestType();
40-
if (sourceType.getRank() != 1 || packOp.getPaddingValue())
41-
return failure();
46+
ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
47+
if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
48+
return rewriter.notifyMatchFailure(
49+
packOp, "expects packing at the innermost dimension");
50+
}
51+
4252
auto reassociation =
4353
getReassociationIndicesForReshape(sourceType, destType);
4454
if (!reassociation)
Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
22

3-
// CHECK: func.func @single_dim_packing(
4-
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
5-
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
6-
// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
3+
// CHECK-LABEL: func.func @single_dim_packing(
4+
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
5+
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
6+
// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
77
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
88
%empty = tensor.empty() : tensor<8x32xf32>
99
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
@@ -12,13 +12,47 @@ func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
1212

1313
// -----
1414

15-
// CHECK: func.func @single_dim_packing_with_padding(
16-
// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
17-
// CHECK-NOT: tensor.expand_shape
18-
// CHECK: tensor.pack
15+
// CHECK-LABEL: func.func @single_dim_packing_with_padding(
16+
// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
17+
// CHECK-NOT: tensor.expand_shape
18+
// CHECK: tensor.pack
1919
func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
2020
%empty = tensor.empty() : tensor<8x32xf32>
2121
%cst = arith.constant 0.000000e+00 : f32
2222
%0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
2323
return %0 : tensor<8x32xf32>
2424
}
25+
26+
// -----
27+
28+
// CHECK-LABEL: func.func @single_last_inner_dim_packing(
29+
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
30+
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
31+
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
32+
func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
33+
%empty = tensor.empty() : tensor<5x8x32xf32>
34+
%0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
35+
return %0 : tensor<5x8x32xf32>
36+
}
37+
38+
// -----
39+
40+
// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
41+
// CHECK-NOT: tensor.expand_shape
42+
// CHECK: tensor.pack
43+
func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
44+
%empty = tensor.empty() : tensor<8x5x32xf32>
45+
%0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
46+
return %0 : tensor<8x5x32xf32>
47+
}
48+
49+
// -----
50+
51+
// CHECK-LABEL: func.func @single_first_inner_dim_packing(
52+
// CHECK-NOT: tensor.expand_shape
53+
// CHECK: tensor.pack
54+
func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
55+
%empty = tensor.empty() : tensor<8x5x32xf32>
56+
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
57+
return %0 : tensor<8x5x32xf32>
58+
}

0 commit comments

Comments
 (0)