Skip to content

Commit ff6aa3a

Browse files
author
Jerry Wu
committed
Add test draft
1 parent a0b00a0 commit ff6aa3a

File tree

1 file changed

+56
-0
lines changed

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,3 +905,59 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
905905
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
906906
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
907907
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
908+
909+
// Per the function name, this exercises bubbling a tensor.pack up through a
// tensor.collapse_shape: a 4-D source is collapsed to 3072x256, then packed
// with inner_tiles = [8, 1].
// NOTE(review): draft test — no CHECK lines yet; add FileCheck annotations
// before landing.
func.func @bubble_up_pack_through_collapse(%src: tensor<192x16x64x4xf32>) -> tensor<384x256x8x1xf32> {
  %flat = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<192x16x64x4xf32> into tensor<3072x256xf32>
  %dest = tensor.empty() : tensor<384x256x8x1xf32>
  %packed = tensor.pack %flat outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %dest : tensor<3072x256xf32> -> tensor<384x256x8x1xf32>
  func.return %packed : tensor<384x256x8x1xf32>
}
915+
916+
// Variant of the collapse case with a non-identity outer_dims_perm and
// permuted inner_dims_pos: dims [1, 2] of the 4-D source are collapsed, then
// packed with outer_dims_perm = [0, 2, 1] and inner_dims_pos = [2, 1].
// NOTE(review): draft test — no CHECK lines yet.
func.func @bubble_up_permuted_pack_through_collapse(%src: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
  %flat = tensor.collapse_shape %src [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
  %dest = tensor.empty() : tensor<4x32x3072x8x1xf32>
  %packed = tensor.pack %flat outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %dest : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
  func.return %packed : tensor<4x32x3072x8x1xf32>
}
922+
923+
// Collapse whose reassociation group [0, 1, 2] folds away unit dims (1x64x1),
// followed by a pack — checks the bubble-up also handles unit-dimension
// collapses (per the function name).
// NOTE(review): draft test — no CHECK lines yet.
func.func @bubble_up_pack_through_unit_collapse(%src: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
  %flat = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
  %dest = tensor.empty() : tensor<8x4x8x1xf32>
  %packed = tensor.pack %flat outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %dest : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
  func.return %packed : tensor<8x4x8x1xf32>
}
929+
930+
// Negative case (per the function name): inner tile 8 on the collapsed
// dimension [1, 2] (64x4 -> 256) — presumably 8 does not divide the innermost
// collapsed source dim 4, so the pack must NOT bubble up; TODO confirm the
// exact divisibility condition against the pass implementation.
// NOTE(review): draft test — no CHECK lines yet.
func.func @no_bubble_up_pack_through_non_divisible_collapse(%src: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
  %flat = tensor.collapse_shape %src [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
  %dest = tensor.empty() : tensor<384x32x8x8xf32>
  %packed = tensor.pack %flat outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %dest : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
  func.return %packed : tensor<384x32x8x8xf32>
}
936+
937+
// Mirror of the bubble-up cases: a tensor.unpack (inner_tiles = [8, 8])
// followed by a tensor.expand_shape — per the function name, the unpack is
// expected to push down through the expand.
// NOTE(review): draft test — no CHECK lines yet.
func.func @push_down_unpack_through_expand(%src: tensor<384x32x8x8xf32>) -> tensor<12x256x256xf32> {
  %dest = tensor.empty() : tensor<3072x256xf32>
  %unpacked = tensor.unpack %src outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %dest : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
  %wide = tensor.expand_shape %unpacked [[0, 1], [2]] : tensor<3072x256xf32> into tensor<12x256x256xf32>
  func.return %wide : tensor<12x256x256xf32>
}
943+
944+
// Permuted variant of the push-down case: the unpack uses
// outer_dims_perm = [0, 2, 1] / inner_dims_pos = [2, 1], and the expand
// splits dim 1 of the unpacked 4x3072x256 result.
// NOTE(review): draft test — no CHECK lines yet.
func.func @push_down_permuted_unpack_through_expand(%src: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
  %dest = tensor.empty() : tensor<4x3072x256xf32>
  %unpacked = tensor.unpack %src outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %dest : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
  %wide = tensor.expand_shape %unpacked [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
  func.return %wide : tensor<4x12x256x256xf32>
}
950+
951+
// Expand that introduces a unit dimension (48 -> 3x16x1) after the unpack —
// checks the push-down also handles unit-dimension expands (per the function
// name).
// NOTE(review): draft test — no CHECK lines yet.
func.func @push_down_unpack_through_unit_expand(%src: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
  %dest = tensor.empty() : tensor<48x256xf32>
  %unpacked = tensor.unpack %src outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %dest : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
  %wide = tensor.expand_shape %unpacked [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
  func.return %wide : tensor<3x16x1x256xf32>
}
957+
958+
// Negative case (per the function name): the expand splits dim 0 as
// 3072 -> 256x12 — presumably the inner tile 8 does not divide the resulting
// innermost expanded dim 12, so the unpack must NOT push down; TODO confirm
// the exact divisibility condition against the pass implementation.
// NOTE(review): draft test — no CHECK lines yet.
func.func @no_push_down_unpack_through_non_divisible_expand(%src: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
  %dest = tensor.empty() : tensor<3072x256xf32>
  %unpacked = tensor.unpack %src outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %dest : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
  %wide = tensor.expand_shape %unpacked [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
  func.return %wide : tensor<256x12x256xf32>
}

0 commit comments

Comments
 (0)