@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
550
550
551
551
// -----
552
552
553
+ func.func @linalg_transpose_tensor_unpack_fold_partial_tile (%arg0: tensor <1 x1 x4 x16 xi32 >) -> tensor <15 x3 xi32 > {
554
+ %0 = tensor.empty () : tensor <1 x1 x16 x4 xi32 >
555
+ %transposed = linalg.transpose ins (%arg0 : tensor <1 x1 x4 x16 xi32 >)
556
+ outs (%0 : tensor <1 x1 x16 x4 xi32 >)
557
+ permutation = [1 , 0 , 3 , 2 ]
558
+ %1 = tensor.empty () : tensor <15 x3 xi32 >
559
+ %unpack = tensor.unpack %transposed
560
+ outer_dims_perm = [0 , 1 ]
561
+ inner_dims_pos = [0 , 1 ]
562
+ inner_tiles = [16 , 4 ] into
563
+ %1 : tensor <1 x1 x16 x4 xi32 > -> tensor <15 x3 xi32 >
564
+ return %unpack : tensor <15 x3 xi32 >
565
+ }
566
+ //CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
567
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
568
+ // CHECK: %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
569
+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
570
+ // CHECK-SAME: outer_dims_perm = [1, 0]
571
+ // CHECK-SAME: inner_dims_pos = [1, 0]
572
+ // CHECK-SAME: inner_tiles = [4, 16]
573
+ // CHECK-SAME: into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
574
+ // CHECK: return %[[UNPACK]] : tensor<15x3xi32>
575
+ // CHECK: }
576
+
577
+ // -----
578
+
553
579
func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes (%arg0: tensor <?x?x?x?xf32 >, %transpose_dest: tensor <?x?x?x?xf32 >, %unpack_dest: tensor <?x?xf32 >, %tile_p : index , %tile_q : index ) -> tensor <?x?xf32 > {
554
580
%transposed = linalg.transpose
555
581
ins (%arg0 : tensor <?x?x?x?xf32 >)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
563
589
into %unpack_dest : tensor <?x?x?x?xf32 > -> tensor <?x?xf32 >
564
590
return %unpack : tensor <?x?xf32 >
565
591
}
566
- // CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
567
592
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
568
593
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
569
594
// CHECK-SAME: %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
570
595
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1 : index
571
596
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0 : index
572
- // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
573
- // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
574
- // CHECK-DAG: %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
575
- // CHECK-DAG: %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
576
- // CHECK: %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
597
+ // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
598
+ // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
599
+ // CHECK: %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
577
600
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
578
601
// CHECK-SAME: outer_dims_perm = [0, 1]
579
602
// CHECK-SAME: inner_dims_pos = [1, 0]
0 commit comments