@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
455
455
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
456
456
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
457
457
// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
458
- // CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
459
- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
460
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
461
- // CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
458
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
462
459
// CHECK: %[[RES:.+]] = linalg.generic
463
460
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
464
- // CHECK-SAME: outs(%[[PACKED_ARG0 ]]
461
+ // CHECK-SAME: outs(%[[EMPTY ]]
465
462
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
466
463
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
467
464
// CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
485
482
// CHECK-LABEL: func.func @unpack_on_input
486
483
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
487
484
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
488
- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
489
- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
490
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
491
- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
492
- // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
493
- // CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
494
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
495
- // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
496
- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
497
- // CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
498
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
499
- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
485
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
500
486
// CHECK: %[[RES:.+]] = linalg.generic
501
487
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
502
- // CHECK-SAME: ins(%[[ARG0_PACK ]]
503
- // CHECK-SAME: outs(%[[ARG1_PACK ]]
488
+ // CHECK-SAME: ins(%[[ARG0 ]]
489
+ // CHECK-SAME: outs(%[[EMPTY ]]
504
490
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
505
491
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
506
492
// CHECK-SAME: into %[[ARG1]]
@@ -524,22 +510,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
524
510
// CHECK-LABEL: func.func @unpack_element_type_change
525
511
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
526
512
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
527
- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
528
- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
529
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
530
- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
531
- // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
532
- // CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
533
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
534
- // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
535
- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
536
- // CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
537
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
538
- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
513
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
539
514
// CHECK: %[[RES:.+]] = linalg.generic
540
515
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
541
- // CHECK-SAME: ins(%[[ARG0_PACK ]]
542
- // CHECK-SAME: outs(%[[ARG1_PACK ]]
516
+ // CHECK-SAME: ins(%[[ARG0 ]]
517
+ // CHECK-SAME: outs(%[[EMPTY ]]
543
518
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
544
519
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
545
520
// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +539,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
564
539
// CHECK-LABEL: func.func @forward_tensor_empty
565
540
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
566
541
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
567
- // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
568
- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
569
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
570
- // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
571
- // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
572
- // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
573
- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
574
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
575
- // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
542
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
576
543
// CHECK: %[[RES:.+]] = linalg.generic
577
544
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
578
- // CHECK-SAME: ins(%[[PACKED_ARG0 ]]
579
- // CHECK-SAME: outs(%[[DEST ]]
545
+ // CHECK-SAME: ins(%[[ARG0 ]]
546
+ // CHECK-SAME: outs(%[[EMPTY ]]
580
547
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
581
548
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
582
549
// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +777,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
810
777
}
811
778
812
779
// CHECK-LABEL: func.func @unpack_empty_inner_dims
813
- // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
814
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
815
- // CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
816
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
780
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
817
781
// CHECK: %[[RES:.+]] = linalg.generic
818
- // CHECK-SAME: ins(%[[PACKED_ARG0 ]]
782
+ // CHECK-SAME: ins(%[[ARG0 ]]
819
783
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
820
784
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
821
785
@@ -943,14 +907,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
943
907
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
944
908
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
945
909
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
946
- // CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
947
- // CHECK: %[[PACK_ARG0:.+]] = linalg.pack
948
- // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
949
- // CHECK-SAME: into %[[PACK_EMPTY]]
950
910
// CHECK: %[[POOL:.+]] = linalg.generic
951
911
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
952
912
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
953
- // CHECK-SAME: ins(%[[PACK_ARG0 ]], %[[ARG1]]
913
+ // CHECK-SAME: ins(%[[ARG0 ]], %[[ARG1]]
954
914
// CHECK-SAME: outs(%[[INIT]]
955
915
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
956
916
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1381,48 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
1421
1381
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
1422
1382
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
1423
1383
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
1384
+
1385
+ // -----
1386
+
1387
+ func.func @push_unpack_in_padded_domain_foldable (%arg0: tensor <8 x8 x4 x8 xf32 >, %dest: tensor <?x64 xf32 >, %arg1: tensor <?x64 xbf16 >) -> tensor <?x64 xbf16 > {
1388
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [4 , 8 ] into %dest : tensor <8 x8 x4 x8 xf32 > -> tensor <?x64 xf32 >
1389
+ %0 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>, affine_map <(d0 , d1 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" ]} ins (%unpack : tensor <?x64 xf32 >) outs (%arg1 : tensor <?x64 xbf16 >) {
1390
+ ^bb0 (%in: f32 , %out: bf16 ):
1391
+ %1 = arith.truncf %in : f32 to bf16
1392
+ linalg.yield %1 : bf16
1393
+ } -> tensor <?x64 xbf16 >
1394
+ return %0 : tensor <?x64 xbf16 >
1395
+ }
1396
+ // CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
1397
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1398
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1399
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
1400
+ // CHECK: %[[EMPTY:.+]] = tensor.empty
1401
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
1402
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1403
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
1404
+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
1405
+ // CHECK-SAME: into %[[ARG2]]
1406
+ // CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
1407
+
1408
+ // -----
1409
+
1410
+ func.func @push_unpack_in_padded_domain_out_used (%arg0: tensor <8 x8 x4 x8 xf32 >, %arg1: tensor <?x64 xf32 >) -> tensor <?x64 xf32 > {
1411
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [4 , 8 ] into %arg1 : tensor <8 x8 x4 x8 xf32 > -> tensor <?x64 xf32 >
1412
+ %0 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>, affine_map <(d0 , d1 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" ]} ins (%unpack : tensor <?x64 xf32 >) outs (%arg1 : tensor <?x64 xf32 >) {
1413
+ ^bb0 (%in: f32 , %out: f32 ):
1414
+ %1 = arith.addf %in , %out : f32
1415
+ linalg.yield %1 : f32
1416
+ } -> tensor <?x64 xf32 >
1417
+ return %0 : tensor <?x64 xf32 >
1418
+ }
1419
+ // CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
1420
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1421
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1422
+ // CHECK: %[[EMPTY:.+]] = tensor.empty
1423
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
1424
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1425
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
1426
+ // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
1427
+ // CHECK-SAME: into %[[ARG1]]
1428
+ // CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
0 commit comments