@@ -956,6 +956,24 @@ func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> ten
956
956
957
957
// -----
958
958
959
+ // Test case: a tensor.collapse_shape that merges only the two outer dims
+ // ([[0, 1], [2]]) of a dynamically shaped tensor feeds a tensor.pack that
+ // tiles the remaining inner dim by 4. The FileCheck directives below expect
+ // the pack to be applied directly to the function argument (tiling dim 2 of
+ // the uncollapsed tensor) and the collapse_shape to run afterwards on the
+ // packed result, with the dynamic extent obtained via tensor.dim on the
+ // argument.
+ func.func @bubble_up_pack_through_collapse_on_outer_dims (%1: tensor <?x16 x4 xf32 >, %dim : index ) -> tensor <?x1 x4 xf32 > {
960
+ %collapsed = tensor.collapse_shape %1 [[0 , 1 ], [2 ]] : tensor <?x16 x4 xf32 > into tensor <?x4 xf32 >
961
+ %2 = tensor.empty (%dim ) : tensor <?x1 x4 xf32 >
962
+ %pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [1 ] inner_tiles = [4 ] into %2 : tensor <?x4 xf32 > -> tensor <?x1 x4 xf32 >
963
+ func.return %pack : tensor <?x1 x4 xf32 >
964
+ }
965
+ // CHECK-LABEL: func.func @bubble_up_pack_through_collapse_on_outer_dims
966
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
967
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
968
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
969
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
970
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x16x1x4xf32>
971
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [4] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x16x1x4xf32>
972
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] : tensor<?x16x1x4xf32> into tensor<?x1x4xf32>
973
+ // CHECK: return %[[COLLAPSED]] : tensor<?x1x4xf32>
974
+
975
+ // -----
976
+
959
977
func.func @no_bubble_up_pack_through_non_divisible_collapse (%1: tensor <3072 x64 x4 xf32 >) -> tensor <384 x32 x8 x8 xf32 > {
960
978
%collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ]] : tensor <3072 x64 x4 xf32 > into tensor <3072 x256 xf32 >
961
979
%2 = tensor.empty () : tensor <384 x32 x8 x8 xf32 >
@@ -1018,6 +1036,24 @@ func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> ten
1018
1036
1019
1037
// -----
1020
1038
1039
+ // Test case: a tensor.unpack (inner tile 8 on dim 1) feeds a
+ // tensor.expand_shape that splits only the outer dim ([[0, 1], [2]]). The
+ // FileCheck directives below expect the expand_shape to be applied to the
+ // packed function argument first, and the unpack to then consume that
+ // expanded value. The unpack directive must reuse the EXPANDED capture
+ // (not re-define it) so the test really verifies the unpack source operand.
+ func.func @push_down_unpack_through_expand_on_outer_dims (%5: tensor <?x32 x8 xf32 >, %dim: index ) -> tensor <?x256 x256 xf32 > {
1040
+ %6 = tensor.empty (%dim ) : tensor <?x256 xf32 >
1041
+ %unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [1 ] inner_tiles = [8 ] into %6 : tensor <?x32 x8 xf32 > -> tensor <?x256 xf32 >
1042
+ %expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] : tensor <?x256 xf32 > into tensor <?x256 x256 xf32 >
1043
+ func.return %expanded : tensor <?x256 x256 xf32 >
1044
+ }
1045
+ // CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
1046
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1047
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1048
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
1049
+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
1050
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
1051
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
1052
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
1053
+ // CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
1054
+
1055
+ // -----
1056
+
1021
1057
func.func @no_push_down_unpack_through_non_divisible_expand (%5: tensor <384 x32 x8 x8 xf32 >) -> tensor <256 x12 x256 xf32 > {
1022
1058
%6 = tensor.empty () : tensor <3072 x256 xf32 >
1023
1059
%unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <384 x32 x8 x8 xf32 > -> tensor <3072 x256 xf32 >
0 commit comments