@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
458
458
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
459
459
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
460
460
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
461
- // CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
462
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
461
+ // CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
462
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
463
463
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
464
464
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
465
- // CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
466
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
465
+ // CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
466
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
467
467
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
468
468
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
469
- // CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
470
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
469
+ // CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
470
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
471
471
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
472
472
// CHECK: %[[RES:.+]] = linalg.generic
473
473
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
474
474
// CHECK-SAME: ins(%[[ARG0_PACK]]
475
475
// CHECK-SAME: outs(%[[ARG1_PACK]]
476
- // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
477
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
476
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
477
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
478
478
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
479
479
480
480
// -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
537
537
// CHECK-LABEL: func.func @forward_tensor_empty
538
538
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
539
539
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
540
- // CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
541
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
540
+ // CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
541
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
542
542
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
543
543
// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
544
544
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
545
- // CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
546
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
545
+ // CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
546
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
547
547
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
548
548
// CHECK: %[[RES:.+]] = linalg.generic
549
549
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
550
550
// CHECK-SAME: ins(%[[PACKED_ARG0]]
551
551
// CHECK-SAME: outs(%[[DEST]]
552
552
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
553
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
553
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
554
554
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
555
555
556
556
// -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
571
571
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
572
572
// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
573
573
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
574
- // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
575
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
574
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
575
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
576
576
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
577
577
578
578
// -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
614
614
// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
615
615
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
616
616
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
617
- // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
618
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
617
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
618
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
619
619
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
620
620
// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
621
621
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
687
687
688
688
// -----
689
689
690
+ func.func @multi_use_pad_pack_propagation (%arg0: tensor <1 x64 x56 x56 xf32 >) -> (tensor <1 x64 x58 x58 xf32 >, tensor <1 x2 x58 x58 x32 xf32 >) {
691
+ %cst = arith.constant 0.000000e+00 : f32
692
+ %padded = tensor.pad %arg0 low [0 , 0 , 1 , 1 ] high [0 , 0 , 1 , 1 ] {
693
+ ^bb0 (%arg3: index , %arg4: index , %arg5: index , %arg6: index ):
694
+ tensor.yield %cst : f32
695
+ } : tensor <1 x64 x56 x56 xf32 > to tensor <1 x64 x58 x58 xf32 >
696
+ %0 = tensor.empty () : tensor <1 x2 x58 x58 x32 xf32 >
697
+ %1 = tensor.pack %padded inner_dims_pos = [1 ] inner_tiles = [32 ] into %0 : tensor <1 x64 x58 x58 xf32 > -> tensor <1 x2 x58 x58 x32 xf32 >
698
+ return %padded , %1 : tensor <1 x64 x58 x58 xf32 >, tensor <1 x2 x58 x58 x32 xf32 >
699
+ }
700
+
701
+ // CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
702
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
703
+ // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
704
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
705
+ // CHECK: %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
706
+ // CHECK-SAME: into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
707
+ // CHECK: %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
708
+ // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
709
+ // CHECK: return %[[UNPACKED]], %[[PADDED]]
710
+
711
+ // -----
712
+
690
713
#map0 = affine_map <(d0 , d1 ) -> (d0 , d1 )>
691
714
func.func @would_break_dominance (%arg0: tensor <128 x256 xi32 >) -> tensor <4 x16 x16 x32 xi32 >{
692
715
%init = tensor.empty () : tensor <128 x256 xi32 >
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
713
736
// CHECK-SAME: outs(%[[EMPTY]]
714
737
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
715
738
// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
716
- // CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
739
+ // CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
717
740
// CHECK-SAME: into %[[ALLOC]]
718
741
719
742
// -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
760
783
761
784
// CHECK-LABEL: func.func @unpack_empty_inner_dims
762
785
// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
763
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
764
- // CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
765
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
786
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
787
+ // CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
788
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
766
789
// CHECK: %[[RES:.+]] = linalg.generic
767
790
// CHECK-SAME: ins(%[[PACKED_ARG0]]
768
791
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
769
- // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
792
+ // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
770
793
771
794
// -----
772
795
773
796
#map0 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
774
797
#map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
775
- func.func @reduction_pack_transpose_inner_dims (%arg0: tensor <128 x256 x32 xi32 >,
798
+ func.func @reduction_pack_transpose_inner_dims (%arg0: tensor <128 x256 x32 xi32 >,
776
799
%arg1: tensor <128 x256 xi32 >) -> tensor <4 x16 x16 x32 xi32 >{
777
800
%elem = linalg.generic {index ing_maps = [#map0 , #map1 ], iterator_types = [" parallel" , " parallel" , " reduction" ]}
778
801
ins (%arg0 : tensor <128 x256 x32 xi32 >)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
810
833
811
834
// -----
812
835
813
- func.func @reduction_pack_with_outer_dims (%arg0: tensor <100 x128 x200 x256 xi32 >, %arg1: tensor <100 xi32 >,
836
+ func.func @reduction_pack_with_outer_dims (%arg0: tensor <100 x128 x200 x256 xi32 >, %arg1: tensor <100 xi32 >,
814
837
%arg2: tensor <128 xi32 >, %init_reduction: tensor <100 x128 x256 xi32 >) -> tensor <4 x16 x100 x16 x32 xi32 >
815
838
{
816
839
%reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
867
890
#map0 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d2 * 2 + d4 , d3 * 2 + d5 )>
868
891
#map1 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d4 , d5 )>
869
892
#map2 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d1 , d2 , d3 )>
870
- func.func @unpack_different_destination_shape (%arg0: tensor <1 x1 x1080 x1920 x16 xi32 >,
893
+ func.func @unpack_different_destination_shape (%arg0: tensor <1 x1 x1080 x1920 x16 xi32 >,
871
894
%filter: tensor <2 x2 xi32 >) -> tensor <16 x540 x960 xi32 >{
872
895
%init = tensor.empty () : tensor <16 x540 x960 xi32 >
873
896
%empty = tensor.empty () : tensor <1 x16 x1080 x1920 xi32 >
0 commit comments