@@ -1,12 +1,11 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -cse --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @pack(
 func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
@@ -33,8 +32,7 @@ transform.sequence failures(propagate) {
 func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {
 
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0]
   // CHECK: : tensor<128x8xf32> to tensor<128x8xf32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
   // CHECK-SAME: : tensor<128x8xf32> into tensor<8x16x8x1xf32>
@@ -64,8 +62,7 @@ func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x13
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
   // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
@@ -100,8 +97,7 @@ transform.sequence failures(propagate) {
 func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
@@ -190,8 +186,7 @@ transform.sequence failures(propagate) {
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)
     -> tensor<200x4x16x100x16x32xi32> {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK: : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
   // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -221,8 +216,7 @@ transform.sequence failures(propagate) {
 func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
                                              %dest: tensor<200x4x16x100x16x32xi32>)
     -> tensor<200x4x16x100x16x32xi32> {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK: : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
   // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -250,13 +244,64 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>
+// CHECK:       func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+  // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+  // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[SRC]], %[[C0]]
+  // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]]
+  // CHECK-DAG: %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
+  // CHECK-DAG: %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
+  // CHECK-DAG: %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+  // CHECK-DAG: %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]]
+  // CHECK:     %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
+  // CHECK:       : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK:     %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]
+  // CHECK-SAME:  : tensor<?x?xf32> into tensor<?x32x?x16xf32>
+  // CHECK:     %[[TRANSP:.+]] = linalg.transpose
+  // CHECK-SAME:   ins(%[[EXPAND]] : tensor<?x32x?x16xf32>)
+  // CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x?x16x32xf32>)
+  // CHECK-SAME:   permutation = [2, 0, 3, 1]
+  // CHECK:     return %[[TRANSP]]
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %source, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %source, %c1 : tensor<?x?xf32>
+  %padding_value = arith.constant 0.0 : f32
+
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %tiled_d0 = arith.ceildivui %d0, %c32 : index
+  %tiled_d1 = arith.ceildivui %d1, %c16 : index
+  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x32xf32>
+  %pack = tensor.pack %source padding_value(%padding_value : f32)
+    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %init_pack
+    : tensor<?x?xf32> -> tensor<?x?x16x32xf32>
+  return %pack : tensor<?x?x16x32xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
 func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
   // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
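Context for the repeated CHECK-line change (not part of the patch itself): every `%[[C0]]` capture disappears because the lowering now emits `tensor.pad` with its zero low padding as a static attribute instead of routing it through an `arith.constant 0 : index`. A minimal side-by-side sketch of the two forms; the function name, `high` amounts, and shapes below are illustrative only:

```mlir
func.func @pad_low_forms(%src: tensor<129x47xf32>) -> (tensor<136x64xf32>, tensor<136x64xf32>) {
  %cst = arith.constant 0.0 : f32
  // Old form: zero low padding passed as SSA index values, which is why the
  // tests had to capture %[[C0]] before matching the pad.
  %c0 = arith.constant 0 : index
  %c7 = arith.constant 7 : index
  %c17 = arith.constant 17 : index
  %a = tensor.pad %src low[%c0, %c0] high[%c7, %c17] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<129x47xf32> to tensor<136x64xf32>
  // New form: the zeros are static attributes of the op, so the tests can
  // match "low[0, 0]" directly and no index constant is created at all.
  %b = tensor.pad %src low[0, 0] high[7, 17] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<129x47xf32> to tensor<136x64xf32>
  return %a, %b : tensor<136x64xf32>, tensor<136x64xf32>
}
```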
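The two affine maps checked in the new dynamic test compute the high padding that rounds a dynamic size up to the next multiple of the tile, i.e. `ceil(s0/t) * t - s0`. A small standalone sketch; `#pad16` and `@high_pad_example` are made-up names for illustration:

```mlir
// Rounds s0 up to the next multiple of 16 and returns the difference.
#pad16 = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>

func.func @high_pad_example(%d: index) -> index {
  // For %d = 47: 47 ceildiv 16 = 3, 3 * 16 = 48, so the result is 48 - 47 = 1.
  %h = affine.apply #pad16()[%d]
  return %h : index
}
```

The static tests bake the same rounding into the shapes: with inner tiles [8, 32], 129 pads to ceil(129/8) * 8 = 136 and 47 pads to ceil(47/32) * 32 = 64, matching the `tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>` checks.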