@@ -906,58 +906,126 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
906
906
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
907
907
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
908
908
909
- func.func @bubble_up_pack_through_collapse (%1: tensor <192 x16 x64 x4 xf32 >) -> tensor <384 x256 x8 x1 xf32 > {
910
- %collapsed = tensor.collapse_shape %1 [[0 , 1 ], [2 , 3 ]] : tensor <192 x16 x64 x4 xf32 > into tensor <3072 x256 xf32 >
911
- %2 = tensor.empty () : tensor <384 x256 x8 x1 xf32 >
912
- %pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <3072 x256 xf32 > -> tensor <384 x256 x8 x1 xf32 >
913
- func.return %pack : tensor <384 x256 x8 x1 xf32 >
909
+ // -----
910
+
911
+ func.func @bubble_up_pack_through_collapse (%1: tensor <?x16 x4 xf32 >, %dim : index ) -> tensor <?x4 x8 x1 xf32 > {
912
+ %collapsed = tensor.collapse_shape %1 [[0 , 1 ], [2 ]] : tensor <?x16 x4 xf32 > into tensor <?x4 xf32 >
913
+ %2 = tensor.empty (%dim ) : tensor <?x4 x8 x1 xf32 >
914
+ %pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <?x4 xf32 > -> tensor <?x4 x8 x1 xf32 >
915
+ func.return %pack : tensor <?x4 x8 x1 xf32 >
914
916
}
917
+ // CHECK-LABEL: func.func @bubble_up_pack_through_collapse
918
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
919
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
920
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
921
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
922
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
923
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
924
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
925
+ // CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
926
+
927
+ // -----
915
928
916
929
func.func @bubble_up_permuted_pack_through_collapse (%1: tensor <4 x192 x16 x256 xf32 >) -> tensor <4 x32 x3072 x8 x1 xf32 > {
917
930
%collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ], [3 ]] : tensor <4 x192 x16 x256 xf32 > into tensor <4 x3072 x256 xf32 >
918
931
%2 = tensor.empty () : tensor <4 x32 x3072 x8 x1 xf32 >
919
932
%pack = tensor.pack %collapsed outer_dims_perm = [0 , 2 , 1 ] inner_dims_pos = [2 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <4 x3072 x256 xf32 > -> tensor <4 x32 x3072 x8 x1 xf32 >
920
933
func.return %pack : tensor <4 x32 x3072 x8 x1 xf32 >
921
934
}
935
+ // CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
936
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
937
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
938
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
939
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
940
+ // CHECK: return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
941
+
942
+ // -----
922
943
923
944
func.func @bubble_up_pack_through_unit_collapse (%1: tensor <1 x64 x1 x4 xf32 >) -> tensor <8 x4 x8 x1 xf32 > {
924
945
%collapsed = tensor.collapse_shape %1 [[0 , 1 , 2 ], [3 ]] : tensor <1 x64 x1 x4 xf32 > into tensor <64 x4 xf32 >
925
946
%2 = tensor.empty () : tensor <8 x4 x8 x1 xf32 >
926
947
%pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <64 x4 xf32 > -> tensor <8 x4 x8 x1 xf32 >
927
948
func.return %pack : tensor <8 x4 x8 x1 xf32 >
928
949
}
950
+ // CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
951
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
952
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
953
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
954
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
955
+ // CHECK: return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
956
+
957
+ // -----
929
958
930
959
func.func @no_bubble_up_pack_through_non_divisible_collapse (%1: tensor <3072 x64 x4 xf32 >) -> tensor <384 x32 x8 x8 xf32 > {
931
960
%collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ]] : tensor <3072 x64 x4 xf32 > into tensor <3072 x256 xf32 >
932
961
%2 = tensor.empty () : tensor <384 x32 x8 x8 xf32 >
933
962
%pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %2 : tensor <3072 x256 xf32 > -> tensor <384 x32 x8 x8 xf32 >
934
963
func.return %pack : tensor <384 x32 x8 x8 xf32 >
935
964
}
965
+ // CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
966
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
967
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
968
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
969
+ // CHECK: return %[[PACK]] : tensor<384x32x8x8xf32>
936
970
937
- func.func @push_down_unpack_through_expand (%5: tensor <384 x32 x8 x8 xf32 >) -> tensor <12 x256 x256 xf32 > {
938
- %6 = tensor.empty () : tensor <3072 x256 xf32 >
939
- %unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <384 x32 x8 x8 xf32 > -> tensor <3072 x256 xf32 >
940
- %expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] : tensor <3072 x256 xf32 > into tensor <12 x256 x256 xf32 >
941
- func.return %expanded : tensor <12 x256 x256 xf32 >
971
+ // -----
972
+
973
+ func.func @push_down_unpack_through_expand (%5: tensor <?x32 x8 x8 xf32 >, %dim: index ) -> tensor <?x256 x256 xf32 > {
974
+ %6 = tensor.empty (%dim ) : tensor <?x256 xf32 >
975
+ %unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <?x32 x8 x8 xf32 > -> tensor <?x256 xf32 >
976
+ %expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] : tensor <?x256 xf32 > into tensor <?x256 x256 xf32 >
977
+ func.return %expanded : tensor <?x256 x256 xf32 >
942
978
}
979
+ // CHECK-LABEL: func.func @push_down_unpack_through_expand
980
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
981
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
982
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
983
+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
984
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
985
+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
986
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
987
+ // CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
988
+
989
+ // -----
943
990
944
991
func.func @push_down_permuted_unpack_through_expand (%5: tensor <4 x32 x384 x8 x8 xf32 >) -> tensor <4 x12 x256 x256 xf32 > {
945
992
%6 = tensor.empty () : tensor <4 x3072 x256 xf32 >
946
993
%unpack = tensor.unpack %5 outer_dims_perm = [0 , 2 , 1 ] inner_dims_pos = [2 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <4 x32 x384 x8 x8 xf32 > -> tensor <4 x3072 x256 xf32 >
947
994
%expanded = tensor.expand_shape %unpack [[0 ], [1 , 2 ], [3 ]] : tensor <4 x3072 x256 xf32 > into tensor <4 x12 x256 x256 xf32 >
948
995
func.return %expanded : tensor <4 x12 x256 x256 xf32 >
949
996
}
997
+ // CHECK-LABEL: func.func @push_down_permuted_unpack_through_expand
998
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
999
+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
1000
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
1001
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
1002
+ // CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32>
1003
+
1004
+ // -----
950
1005
951
1006
func.func @push_down_unpack_through_unit_expand (%5: tensor <6 x32 x8 x8 xf32 >) -> tensor <3 x16 x1 x256 xf32 > {
952
1007
%6 = tensor.empty () : tensor <48 x256 xf32 >
953
1008
%unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <6 x32 x8 x8 xf32 > -> tensor <48 x256 xf32 >
954
1009
%expanded = tensor.expand_shape %unpack [[0 , 1 , 2 ], [3 ]] : tensor <48 x256 xf32 > into tensor <3 x16 x1 x256 xf32 >
955
1010
func.return %expanded : tensor <3 x16 x1 x256 xf32 >
956
1011
}
1012
+ // CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
1013
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1014
+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
1015
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
1016
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
1017
+ // CHECK: return %[[UNPACK]] : tensor<3x16x1x256xf32>
1018
+
1019
+ // -----
957
1020
958
1021
func.func @no_push_down_unpack_through_non_divisible_expand (%5: tensor <384 x32 x8 x8 xf32 >) -> tensor <256 x12 x256 xf32 > {
959
1022
%6 = tensor.empty () : tensor <3072 x256 xf32 >
960
1023
%unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <384 x32 x8 x8 xf32 > -> tensor <3072 x256 xf32 >
961
1024
%expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] : tensor <3072 x256 xf32 > into tensor <256 x12 x256 xf32 >
962
1025
func.return %expanded : tensor <256 x12 x256 xf32 >
963
1026
}
1027
+ // CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
1028
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1029
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
1030
+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
1031
+ // CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
0 commit comments