@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
96
96
97
97
// -----
98
98
99
+ // This is the same as pack_as_pad, but since we explicitly set {lowerPadLikeWithInsertSlice = false}, it should not
100
+ // be lowered to insert_slice.
101
+ // CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
102
+ func.func @pack_as_pad_disabled_insert_slice (%arg0: tensor <129 x47 x16 x16 xf32 >, %arg1: tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 >) -> tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 > {
103
+ %cst_0 = arith.constant 0.0 : f32
104
+ // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
105
+ // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
106
+ // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
107
+ // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
108
+ // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
109
+ // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
110
+ %pack = tensor.pack %arg0 padding_value (%cst_0 : f32 ) inner_dims_pos = [0 , 1 , 2 , 3 ] inner_tiles = [136 , 64 , 16 , 16 ] into %arg1
111
+ : tensor <129 x47 x16 x16 xf32 > -> tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 >
112
+ return %pack : tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 >
113
+ }
114
+
115
+ module attributes {transform.with_named_sequence } {
116
+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
117
+ %pack = transform.structured.match ops {[" tensor.pack" ]} in %module_op
118
+ : (!transform.any_op ) -> !transform.op <" tensor.pack" >
119
+ transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false }: (!transform.op <" tensor.pack" >)
120
+ -> (!transform.op <" tensor.pad" >, !transform.op <" tensor.expand_shape" >, !transform.op <" linalg.transpose" >)
121
+ transform.yield
122
+ }
123
+ }
124
+
125
+ // -----
126
+
99
127
// Check that we don't lower the following pack as a pad.
100
128
// Although all the outer most dimensions in the resulting shape are 1s,
101
129
// some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
233
261
234
262
// -----
235
263
264
+ // This is the same as unpack_as_pad, but since we explicitly set {lowerUnpadLikeWithExtractSlice = false}, it should not
265
+ // be lowered to extract_slice.
266
+ // CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
267
+ func.func @unpack_as_pad_disabled_extract_slice (%arg0: tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 >, %arg1: tensor <129 x47 x16 x16 xf32 >) -> tensor <129 x47 x16 x16 xf32 > {
268
+ %cst_0 = arith.constant 0.0 : f32
269
+
270
+ // tensor.unpack is lowered to linalg.transpose + tensor.collapse_shape + tensor.extract_slice
271
+ // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
272
+ // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
273
+ // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
274
+ // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
275
+ // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
276
+ %pack = tensor.unpack %arg0 inner_dims_pos = [0 , 1 , 2 , 3 ] inner_tiles = [136 , 64 , 16 , 16 ] into %arg1
277
+ : tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 > -> tensor <129 x47 x16 x16 xf32 >
278
+ return %pack : tensor <129 x47 x16 x16 xf32 >
279
+ }
280
+
281
+ module attributes {transform.with_named_sequence } {
282
+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
283
+ %unpack = transform.structured.match ops {[" tensor.unpack" ]} in %module_op
284
+ : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
285
+ transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false }: (!transform.op <" tensor.unpack" >)
286
+ -> (!transform.op <" tensor.empty" >,
287
+ !transform.op <" linalg.transpose" >,
288
+ !transform.op <" tensor.collapse_shape" >,
289
+ !transform.op <" tensor.extract_slice" >)
290
+ transform.yield
291
+ }
292
+ }
293
+
294
+ // -----
295
+
236
296
// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
237
297
func.func @pack_with_outer_dims_perm (%src: tensor <100 x200 x128 x256 xi32 >,
238
298
%dest: tensor <200 x4 x16 x100 x16 x32 xi32 >)
@@ -572,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
572
632
module attributes {transform.with_named_sequence } {
573
633
transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
574
634
%unpack = transform.structured.match ops {[" tensor.unpack" ]} in %module_op
575
- : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
635
+ : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
576
636
transform.structured.lower_unpack %unpack : (!transform.op <" tensor.unpack" >)
577
637
-> (!transform.op <" tensor.empty" >,
578
638
!transform.op <" linalg.transpose" >,
@@ -627,9 +687,9 @@ module attributes {transform.with_named_sequence} {
627
687
// CHECK-LABEL: @unpack_with_outer_dims_perm
628
688
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
629
689
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
630
- // CHECK: %[[TRAN:.*]] = linalg.transpose
631
- // CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
632
- // CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
690
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
691
+ // CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
692
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
633
693
// CHECK-SAME: permutation = [1, 3, 0, 2]
634
694
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
635
695
// CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -638,7 +698,7 @@ module attributes {transform.with_named_sequence} {
638
698
// CHECK: linalg.copy ins(%[[SLICE]]
639
699
// CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
640
700
func.func @unpack_with_outer_dims_perm (%arg0: tensor <32 x64 xf32 >, %arg1: tensor <2 x4 x32 x8 xf32 >) -> tensor <32 x64 xf32 > {
641
- %unpack = tensor.unpack %arg1 outer_dims_perm = [1 , 0 ]
701
+ %unpack = tensor.unpack %arg1 outer_dims_perm = [1 , 0 ]
642
702
inner_dims_pos = [1 , 0 ] inner_tiles = [32 , 8 ] into %arg0 : tensor <2 x4 x32 x8 xf32 > -> tensor <32 x64 xf32 >
643
703
return %unpack : tensor <32 x64 xf32 >
644
704
}
0 commit comments