Skip to content

Commit 86024e8

Browse files
author
Jerry Wu
committed
Add new tests
1 parent 143243f commit 86024e8

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,24 @@ func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> ten
956956

957957
// -----
958958

959+
// Test: a tensor.pack whose tiled dim (inner_dims_pos = [1]) is untouched by
// the preceding collapse_shape (which only merges the outer dims [0, 1]) can
// be bubbled up above the collapse: the pack is re-expressed on the
// pre-collapse tensor<?x16x4xf32> and the collapse is replayed on the packed
// result.
func.func @bubble_up_pack_through_collapse_on_outer_dims(%src: tensor<?x16x4xf32>, %d0 : index) -> tensor<?x1x4xf32> {
  %merged = tensor.collapse_shape %src [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
  %dest = tensor.empty(%d0) : tensor<?x1x4xf32>
  %packed = tensor.pack %merged outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<?x4xf32> -> tensor<?x1x4xf32>
  func.return %packed : tensor<?x1x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_through_collapse_on_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x16x1x4xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [4] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x16x1x4xf32>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] : tensor<?x16x1x4xf32> into tensor<?x1x4xf32>
// CHECK: return %[[COLLAPSED]] : tensor<?x1x4xf32>

// -----
976+
959977
func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
960978
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
961979
%2 = tensor.empty() : tensor<384x32x8x8xf32>
@@ -1018,6 +1036,24 @@ func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> ten
10181036

10191037
// -----
10201038

1039+
// Test: a tensor.unpack whose tiled dim (inner_dims_pos = [1]) is untouched
// by the following expand_shape (which only splits the outer dims [0, 1]) can
// be pushed down below the expand: the expand is replayed on the still-packed
// source and the unpack is re-expressed on the post-expand shape.
func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
  %6 = tensor.empty(%dim) : tensor<?x256xf32>
  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor<?x32x8xf32> -> tensor<?x256xf32>
  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
  func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// NOTE: use %[[EXPANDED]] (a match of the earlier capture), not
// %[[EXPANDED:.+]] — the latter would silently RE-define the FileCheck
// variable and accept any operand, so the test would not verify that the
// unpack actually consumes the expand_shape result.
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>

// -----
1056+
10211057
func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
10221058
%6 = tensor.empty() : tensor<3072x256xf32>
10231059
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>

0 commit comments

Comments
 (0)