@@ -345,3 +345,164 @@ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_s
345
345
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
346
346
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
347
347
// CHECK: }
348
+
349
+ // -----
350
+
351
+ func.func @linalg_transpose_tensor_pack_fold (%arg0: tensor <56 x57 x1 x64 xf32 >) -> tensor <1 x57 x56 x2 x32 xf32 > { // Static-shape case: a linalg.transpose feeds a tensor.pack; the checks after this function expect a single pack applied directly to %arg0.
352
+ %0 = tensor.empty () : tensor <1 x56 x57 x64 xf32 >
353
+ %transposed = linalg.transpose
354
+ ins (%arg0 : tensor <56 x57 x1 x64 xf32 >)
355
+ outs (%0 : tensor <1 x56 x57 x64 xf32 >)
356
+ permutation = [2 , 0 , 1 , 3 ] // moves the unit dim to the front: 56x57x1x64 -> 1x56x57x64
357
+
358
+ %1 = tensor.empty () : tensor <1 x57 x56 x2 x32 xf32 >
359
+ %pack = tensor.pack %transposed
360
+ outer_dims_perm = [0 , 2 , 1 , 3 ] // composed with the transpose permutation, the checks expect [2, 1, 0, 3] on the resulting pack
361
+ inner_dims_pos = [3 ] // the tiled dim is the last one, which the transpose leaves in place
362
+ inner_tiles = [32 ] // 64 / 32 = 2 outer tiles on the last dim
363
+ into %1 : tensor <1 x56 x57 x64 xf32 > -> tensor <1 x57 x56 x2 x32 xf32 >
364
+ return %pack : tensor <1 x57 x56 x2 x32 xf32 >
365
+ }
366
+ //CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
367
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
368
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
369
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
370
+ // CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
371
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
372
+ // CHECK-SAME: into %[[INIT]]
373
+ // CHECK: return %[[PACK]]
374
+
375
+ // -----
376
+
377
+ func.func @linalg_transpose_tensor_pack_fold_with_padding (%arg0: tensor <56 x57 x1 x55 xf32 >, %padding: f32 ) -> tensor <1 x57 x56 x2 x32 xf32 > { // Same transpose + pack pair as the plain static case, but the pack carries a padding value; the checks after this function expect it to be kept on the resulting pack.
378
+ %0 = tensor.empty () : tensor <1 x56 x57 x55 xf32 >
379
+ %transpose = linalg.transpose
380
+ ins (%arg0 : tensor <56 x57 x1 x55 xf32 >)
381
+ outs (%0 : tensor <1 x56 x57 x55 xf32 >)
382
+ permutation = [2 , 0 , 1 , 3 ] // 56x57x1x55 -> 1x56x57x55
383
+
384
+ %1 = tensor.empty () : tensor <1 x57 x56 x2 x32 xf32 >
385
+ %pack = tensor.pack %transpose padding_value (%padding : f32 ) // last dim is 55, not a multiple of the tile size 32, hence the padding operand
386
+ outer_dims_perm = [0 , 2 , 1 , 3 ] // the checks expect [2, 1, 0, 3] once the transpose is absorbed
387
+ inner_dims_pos = [3 ]
388
+ inner_tiles = [32 ] // 55 elements tiled by 32 give the 2x32 trailing dims of the result type
389
+ into %1 : tensor <1 x56 x57 x55 xf32 > -> tensor <1 x57 x56 x2 x32 xf32 >
390
+ return %pack : tensor <1 x57 x56 x2 x32 xf32 >
391
+ }
392
+ //CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_with_padding(
393
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
394
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
395
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
396
+ // CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
397
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
398
+ // CHECK-SAME: into %[[INIT]]
399
+ // CHECK: return %[[PACK]]
400
+
401
+ // -----
402
+
403
+ func.func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm (%arg0: tensor <56 x57 x1 x64 xf32 >) -> tensor <1 x56 x57 x2 x32 xf32 > { // The pack here has no outer_dims_perm of its own; the checks after this function expect the resulting pack to carry the transpose permutation instead.
404
+ %0 = tensor.empty () : tensor <1 x56 x57 x64 xf32 >
405
+ %transposed = linalg.transpose
406
+ ins (%arg0 : tensor <56 x57 x1 x64 xf32 >)
407
+ outs (%0 : tensor <1 x56 x57 x64 xf32 >)
408
+ permutation = [2 , 0 , 1 , 3 ] // expected to reappear as outer_dims_perm = [2, 0, 1, 3] on the pack
409
+
410
+ %1 = tensor.empty () : tensor <1 x56 x57 x2 x32 xf32 >
411
+ %pack = tensor.pack %transposed
412
+ inner_dims_pos = [3 ] // only the last dim is tiled; outer dims keep the transposed order 1x56x57
413
+ inner_tiles = [32 ] // 64 / 32 = 2 outer tiles on the last dim
414
+ into %1 : tensor <1 x56 x57 x64 xf32 > -> tensor <1 x56 x57 x2 x32 xf32 >
415
+ return %pack : tensor <1 x56 x57 x2 x32 xf32 >
416
+ }
417
+ //CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
418
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
419
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x2x32xf32>
420
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
421
+ // CHECK-SAME: outer_dims_perm = [2, 0, 1, 3]
422
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
423
+ // CHECK-SAME: into %[[INIT]]
424
+ // CHECK: return %[[PACK]]
425
+
426
+ // -----
427
+
428
+ func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change (%arg0: tensor <25 x30 x35 x40 xf32 >, %transpose_dest: tensor <35 x40 x25 x30 xf32 >, %pack_dest: tensor <3 x35 x5 x8 x5 x10 x5 xf32 >) -> tensor <3 x35 x5 x8 x5 x10 x5 xf32 > { // Three tiled dims plus a non-trivial outer permutation; the checks after this function expect both outer_dims_perm and inner_dims_pos to be remapped, and the pack to write into a fresh tensor.empty rather than %pack_dest.
429
+ %transposed = linalg.transpose
430
+ ins (%arg0 : tensor <25 x30 x35 x40 xf32 >)
431
+ outs (%transpose_dest : tensor <35 x40 x25 x30 xf32 >)
432
+ permutation = [2 , 3 , 0 , 1 ] // 25x30x35x40 -> 35x40x25x30
433
+
434
+ %pack = tensor.pack %transposed
435
+ outer_dims_perm = [3 , 0 , 2 , 1 ] // the checks expect [1, 2, 0, 3] after mapping through the transpose permutation
436
+ inner_dims_pos = [1 , 3 , 2 ] // the checks expect [3, 1, 0], i.e. the same dims expressed in %arg0's layout
437
+ inner_tiles = [5 , 10 , 5 ] // 40/5 = 8, 30/10 = 3, 25/5 = 5 outer tiles; tile sizes are expected unchanged
438
+ into %pack_dest : tensor <35 x40 x25 x30 xf32 > -> tensor <3 x35 x5 x8 x5 x10 x5 xf32 >
439
+ return %pack : tensor <3 x35 x5 x8 x5 x10 x5 xf32 >
440
+ }
441
+ //CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(
442
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<25x30x35x40xf32>,
443
+ // CHECK-SAME: %[[ARG1:.+]]: tensor<35x40x25x30xf32>,
444
+ // CHECK-SAME: %[[ARG2:.+]]: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
445
+ // CHECK: %[[VAL0:.+]] = tensor.empty() : tensor<3x35x5x8x5x10x5xf32>
446
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
447
+ // CHECK-SAME: outer_dims_perm = [1, 2, 0, 3]
448
+ // CHECK-SAME: inner_dims_pos = [3, 1, 0]
449
+ // CHECK-SAME: inner_tiles = [5, 10, 5]
450
+ // CHECK-SAME: into %[[VAL0]]
451
+ // CHECK: return %[[PACK]]
452
+
453
+ // -----
454
+
455
+ func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes (%arg0: tensor <?x?x?x?xf32 >, %transpose_dest: tensor <?x?x?x?xf32 >, %pack_dest: tensor <?x?x?x?x?x?x?xf32 >, %tile_p : index , %tile_q : index , %tile_r : index ) -> tensor <?x?x?x?x?x?x?xf32 > { // Fully dynamic variant: all dims and all three tile sizes are runtime values.
456
+ %transposed = linalg.transpose
457
+ ins (%arg0 : tensor <?x?x?x?xf32 >)
458
+ outs (%transpose_dest : tensor <?x?x?x?xf32 >)
459
+ permutation = [2 , 3 , 0 , 1 ]
460
+
461
+ %pack = tensor.pack %transposed
462
+ outer_dims_perm = [3 , 0 , 2 , 1 ] // expected as [1, 2, 0, 3] on the resulting pack
463
+ inner_dims_pos = [1 , 3 , 2 ] // expected as [3, 1, 0] on the resulting pack
464
+ inner_tiles = [%tile_p , %tile_q , %tile_r ] // tile operands are expected to be forwarded unchanged
465
+ into %pack_dest : tensor <?x?x?x?xf32 > -> tensor <?x?x?x?x?x?x?xf32 >
466
+ return %pack : tensor <?x?x?x?x?x?x?xf32 >
467
+ }
+ // The ceildiv map is captured once in a `$`-prefixed (global) FileCheck
+ // variable so it stays defined past the label directive, and every
+ // affine.apply below must reference that same map instead of re-capturing it.
468
+ // CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
469
+ // CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
470
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>,
471
+ // CHECK-SAME: %[[ARG2:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index) -> tensor<?x?x?x?x?x?x?xf32> {
472
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
473
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
474
+ // CHECK: %[[C2:.+]] = arith.constant 2 : index
475
+ // CHECK: %[[C3:.+]] = arith.constant 3 : index
476
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
477
+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
478
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
479
+ // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
480
+ // CHECK: %[[VAL0:.+]] = affine.apply #[[$MAP]]()[%[[DIM2]], %[[ARG3]]]
481
+ // CHECK: %[[VAL1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[ARG4]]]
482
+ // CHECK: %[[VAL2:.+]] = affine.apply #[[$MAP]]()[%[[DIM]], %[[ARG5]]]
483
+ // CHECK: %[[VAL3:.+]] = tensor.empty(%[[VAL1]], %[[DIM1]], %[[VAL2]], %[[VAL0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?x?x?x?x?xf32>
484
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [3, 1, 0] inner_tiles = [%[[ARG3]], %[[ARG4]], %[[ARG5]]] into %[[VAL3]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
485
+ // CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
486
+
487
+ // -----
488
+
489
+ func.func @linalg_transpose_tensor_cast_tensor_pack_fold (%arg0: tensor <56 x57 x1 x64 xf32 >) -> tensor <1 x57 x56 x2 x32 xf32 > { // A tensor.cast sits between the transpose and the pack; the checks after this function only require that both linalg.transpose and tensor.pack are still present. NOTE(review): presumably a negative test for folding through tensor.cast — confirm.
490
+ %0 = tensor.empty () : tensor <1 x56 x57 x64 xf32 >
491
+ %transposed = linalg.transpose
492
+ ins (%arg0 : tensor <56 x57 x1 x64 xf32 >)
493
+ outs (%0 : tensor <1 x56 x57 x64 xf32 >)
494
+ permutation = [2 , 0 , 1 , 3 ] // 56x57x1x64 -> 1x56x57x64
495
+
496
+ %transposed_cast = tensor.cast %transposed : tensor <1 x56 x57 x64 xf32 > to tensor <?x56 x57 x64 xf32 > // makes the leading static dim dynamic, so the pack no longer consumes the transpose result directly
497
+ %1 = tensor.empty () : tensor <1 x57 x56 x2 x32 xf32 >
498
+ %pack = tensor.pack %transposed_cast
499
+ outer_dims_perm = [0 , 2 , 1 , 3 ]
500
+ inner_dims_pos = [3 ]
501
+ inner_tiles = [32 ] // 64 / 32 = 2 outer tiles on the last dim
502
+ into %1 : tensor <?x56 x57 x64 xf32 > -> tensor <1 x57 x56 x2 x32 xf32 >
503
+ return %pack : tensor <1 x57 x56 x2 x32 xf32 >
504
+ }
505
+ //CHECK-LABEL: func @linalg_transpose_tensor_cast_tensor_pack_fold(
506
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
507
+ // CHECK: linalg.transpose
508
+ // CHECK: tensor.pack
0 commit comments