@@ -905,3 +905,59 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
905
905
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
906
906
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
907
907
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
908
+
909
+ func.func @bubble_up_pack_through_collapse (%1: tensor <192 x16 x64 x4 xf32 >) -> tensor <384 x256 x8 x1 xf32 > {
910
+ %collapsed = tensor.collapse_shape %1 [[0 , 1 ], [2 , 3 ]] : tensor <192 x16 x64 x4 xf32 > into tensor <3072 x256 xf32 >
911
+ %2 = tensor.empty () : tensor <384 x256 x8 x1 xf32 >
912
+ %pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <3072 x256 xf32 > -> tensor <384 x256 x8 x1 xf32 >
913
+ func.return %pack : tensor <384 x256 x8 x1 xf32 >
914
+ }
915
+
916
+ func.func @bubble_up_permuted_pack_through_collapse (%1: tensor <4 x192 x16 x256 xf32 >) -> tensor <4 x32 x3072 x8 x1 xf32 > {
917
+ %collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ], [3 ]] : tensor <4 x192 x16 x256 xf32 > into tensor <4 x3072 x256 xf32 >
918
+ %2 = tensor.empty () : tensor <4 x32 x3072 x8 x1 xf32 >
919
+ %pack = tensor.pack %collapsed outer_dims_perm = [0 , 2 , 1 ] inner_dims_pos = [2 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <4 x3072 x256 xf32 > -> tensor <4 x32 x3072 x8 x1 xf32 >
920
+ func.return %pack : tensor <4 x32 x3072 x8 x1 xf32 >
921
+ }
922
+
923
+ func.func @bubble_up_pack_through_unit_collapse (%1: tensor <1 x64 x1 x4 xf32 >) -> tensor <8 x4 x8 x1 xf32 > {
924
+ %collapsed = tensor.collapse_shape %1 [[0 , 1 , 2 ], [3 ]] : tensor <1 x64 x1 x4 xf32 > into tensor <64 x4 xf32 >
925
+ %2 = tensor.empty () : tensor <8 x4 x8 x1 xf32 >
926
+ %pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <64 x4 xf32 > -> tensor <8 x4 x8 x1 xf32 >
927
+ func.return %pack : tensor <8 x4 x8 x1 xf32 >
928
+ }
929
+
930
+ func.func @no_bubble_up_pack_through_non_divisible_collapse (%1: tensor <3072 x64 x4 xf32 >) -> tensor <384 x32 x8 x8 xf32 > {
931
+ %collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ]] : tensor <3072 x64 x4 xf32 > into tensor <3072 x256 xf32 >
932
+ %2 = tensor.empty () : tensor <384 x32 x8 x8 xf32 >
933
+ %pack = tensor.pack %collapsed outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %2 : tensor <3072 x256 xf32 > -> tensor <384 x32 x8 x8 xf32 >
934
+ func.return %pack : tensor <384 x32 x8 x8 xf32 >
935
+ }
936
+
937
+ func.func @push_down_unpack_through_expand (%5: tensor <384 x32 x8 x8 xf32 >) -> tensor <12 x256 x256 xf32 > {
938
+ %6 = tensor.empty () : tensor <3072 x256 xf32 >
939
+ %unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <384 x32 x8 x8 xf32 > -> tensor <3072 x256 xf32 >
940
+ %expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] : tensor <3072 x256 xf32 > into tensor <12 x256 x256 xf32 >
941
+ func.return %expanded : tensor <12 x256 x256 xf32 >
942
+ }
943
+
944
+ func.func @push_down_permuted_unpack_through_expand (%5: tensor <4 x32 x384 x8 x8 xf32 >) -> tensor <4 x12 x256 x256 xf32 > {
945
+ %6 = tensor.empty () : tensor <4 x3072 x256 xf32 >
946
+ %unpack = tensor.unpack %5 outer_dims_perm = [0 , 2 , 1 ] inner_dims_pos = [2 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <4 x32 x384 x8 x8 xf32 > -> tensor <4 x3072 x256 xf32 >
947
+ %expanded = tensor.expand_shape %unpack [[0 ], [1 , 2 ], [3 ]] : tensor <4 x3072 x256 xf32 > into tensor <4 x12 x256 x256 xf32 >
948
+ func.return %expanded : tensor <4 x12 x256 x256 xf32 >
949
+ }
950
+
951
+ func.func @push_down_unpack_through_unit_expand (%5: tensor <6 x32 x8 x8 xf32 >) -> tensor <3 x16 x1 x256 xf32 > {
952
+ %6 = tensor.empty () : tensor <48 x256 xf32 >
953
+ %unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <6 x32 x8 x8 xf32 > -> tensor <48 x256 xf32 >
954
+ %expanded = tensor.expand_shape %unpack [[0 , 1 , 2 ], [3 ]] : tensor <48 x256 xf32 > into tensor <3 x16 x1 x256 xf32 >
955
+ func.return %expanded : tensor <3 x16 x1 x256 xf32 >
956
+ }
957
+
958
+ func.func @no_push_down_unpack_through_non_divisible_expand (%5: tensor <384 x32 x8 x8 xf32 >) -> tensor <256 x12 x256 xf32 > {
959
+ %6 = tensor.empty () : tensor <3072 x256 xf32 >
960
+ %unpack = tensor.unpack %5 outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <384 x32 x8 x8 xf32 > -> tensor <3072 x256 xf32 >
961
+ %expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] : tensor <3072 x256 xf32 > into tensor <256 x12 x256 xf32 >
962
+ func.return %expanded : tensor <256 x12 x256 xf32 >
963
+ }
0 commit comments