@@ -464,6 +464,129 @@ module attributes {transform.with_named_sequence} {
464
464
465
465
// -----
466
466
467
+ // Check that we can lower unpack with dynamic dimensions in the input and destination.
468
+ // CHECK-LABEL: func.func @unpack_with_dynamic_input_dest(
469
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x8x16xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
470
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
471
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
472
+ // CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
473
+ // CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
474
+ // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM01]]) : tensor<?x8x?x16xf32>
475
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
476
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x8x16xf32>)
477
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x?x16xf32>)
478
+ // CHECK-SAME: permutation = [0, 2, 1, 3]
479
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
480
+ // CHECK-SAME: : tensor<?x8x?x16xf32> into tensor<?x?xf32>
481
+ // CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
482
+ // CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
483
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
484
+ // CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
485
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
486
+ // CHECK-SAME: outs(%[[ARG1]] : tensor<?x?xf32>)
487
+ // Payload for this case: the inner tiles are static (8 and 16), while the two
+ // outer dims of %arg0 and both dims of the destination %arg1 are dynamic.
+ func.func @unpack_with_dynamic_input_dest (%arg0: tensor <?x?x8 x16 xf32 >, %arg1: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
488
+ %unpack = tensor.unpack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 16 ] into %arg1 : tensor <?x?x8 x16 xf32 > -> tensor <?x?xf32 >
489
+ return %unpack : tensor <?x?xf32 >
490
+ }
491
+
492
+ // Transform script for this case: match the tensor.unpack ops in the payload
+ // module and apply transform.structured.lower_unpack to them.
+ module attributes {transform.with_named_sequence } {
493
+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
494
+ %unpack = transform.structured.match ops {[" tensor.unpack" ]} in %module_op
495
+ : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
496
+ // The four results are typed as handles to tensor.empty, linalg.transpose,
+ // tensor.collapse_shape and tensor.extract_slice; they are not used afterwards.
+ transform.structured.lower_unpack %unpack : (!transform.op <" tensor.unpack" >)
497
+ -> (!transform.op <" tensor.empty" >,
498
+ !transform.op <" linalg.transpose" >,
499
+ !transform.op <" tensor.collapse_shape" >,
500
+ !transform.op <" tensor.extract_slice" >)
501
+ transform.yield
502
+ }
503
+ }
504
+
505
+ // -----
506
+
507
+ // Check that we can lower unpack with dynamic dimensions in the input, destination, and inner_tiles.
508
+ // CHECK-LABEL: func.func @unpack_fully_dynamic(
509
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
510
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
511
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
512
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
513
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
514
+ // CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
515
+ // CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
516
+ // CHECK-DAG: %[[DIM02:.*]] = tensor.dim %[[ARG0]], %[[C2]]
517
+ // CHECK-DAG: %[[DIM03:.*]] = tensor.dim %[[ARG0]], %[[C3]]
518
+ // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM02]], %[[DIM01]], %[[DIM03]]) : tensor<?x?x?x?xf32>
519
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
520
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
521
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x?x?xf32>)
522
+ // CHECK-SAME: permutation = [0, 2, 1, 3]
523
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
524
+ // CHECK-SAME: : tensor<?x?x?x?xf32> into tensor<?x?xf32>
525
+ // CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
526
+ // CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
527
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
528
+ // CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
529
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
530
+ // CHECK-SAME: outs(%[[ARG1]] : tensor<?x?xf32>)
531
+ // Payload for this case: every dim of %source and %dest is dynamic, and the
+ // inner tile sizes are the runtime index arguments %tile_n and %tile_m.
+ func.func @unpack_fully_dynamic (%source: tensor <?x?x?x?xf32 >, %dest: tensor <?x?xf32 >, %tile_n : index , %tile_m : index ) -> tensor <?x?xf32 > {
532
+ %0 = tensor.unpack %source inner_dims_pos = [0 , 1 ] inner_tiles = [%tile_n , %tile_m ] into %dest : tensor <?x?x?x?xf32 > -> tensor <?x?xf32 >
533
+ return %0 : tensor <?x?xf32 >
534
+ }
535
+ // Transform script for this case: match the tensor.unpack ops in the payload
+ // module and apply transform.structured.lower_unpack to them.
+ module attributes {transform.with_named_sequence } {
536
+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
537
+ %unpack = transform.structured.match ops {[" tensor.unpack" ]} in %module_op
538
+ : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
539
+ // The four results are typed as handles to tensor.empty, linalg.transpose,
+ // tensor.collapse_shape and tensor.extract_slice; they are not used afterwards.
+ transform.structured.lower_unpack %unpack : (!transform.op <" tensor.unpack" >)
540
+ -> (!transform.op <" tensor.empty" >,
541
+ !transform.op <" linalg.transpose" >,
542
+ !transform.op <" tensor.collapse_shape" >,
543
+ !transform.op <" tensor.extract_slice" >)
544
+ transform.yield
545
+ }
546
+ }
547
+
548
+ // -----
549
+
550
+ // Check that we can lower unpack "as unpad" with dynamic dims.
551
+ // CHECK-LABEL: func.func @unpack_as_pad_dynamic(
552
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
553
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
554
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
555
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
556
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
557
+ // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
558
+ // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
559
+ // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
560
+ // CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
561
+ // CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
562
+ // offsets.
563
+ // CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
564
+ // sizes.
565
+ // CHECK-SAME: [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
566
+ // strides multipliers.
567
+ // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
568
+ // CHECK-SAME: : tensor<1x1x1x1x?x?x?x?xf32> to tensor<?x?x?x?xf32>
569
+ // Payload for this case: the four outer dims of %arg0 are statically 1, its
+ // four inner (tile) dims and all dims of the destination %arg1 are dynamic in
+ // the types, and the inner tile sizes are the constants 136, 64, 16 and 16.
+ func.func @unpack_as_pad_dynamic (%arg0: tensor <1 x1 x1 x1 x?x?x?x?xf32 >, %arg1: tensor <?x?x?x?xf32 >) -> tensor <?x?x?x?xf32 > {
570
+ %unpack = tensor.unpack %arg0 inner_dims_pos = [0 , 1 , 2 , 3 ] inner_tiles = [136 , 64 , 16 , 16 ] into %arg1
571
+ : tensor <1 x1 x1 x1 x?x?x?x?xf32 > -> tensor <?x?x?x?xf32 >
572
+ return %unpack : tensor <?x?x?x?xf32 >
573
+ }
574
+
575
+ // Transform script for this case: match the tensor.unpack ops in the payload
+ // module and apply transform.structured.lower_unpack to them.
+ module attributes {transform.with_named_sequence } {
576
+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
577
+ %unpack = transform.structured.match ops {[" tensor.unpack" ]} in %module_op
578
+ : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
579
+ // The four results are typed as handles to tensor.empty, linalg.transpose,
+ // tensor.collapse_shape and tensor.extract_slice; they are not used afterwards.
+ // NOTE(review): the CHECK lines for this case only match a tensor.extract_slice;
+ // presumably the other result handles are empty here - confirm.
+ transform.structured.lower_unpack %unpack : (!transform.op <" tensor.unpack" >)
580
+ -> (!transform.op <" tensor.empty" >,
581
+ !transform.op <" linalg.transpose" >,
582
+ !transform.op <" tensor.collapse_shape" >,
583
+ !transform.op <" tensor.extract_slice" >)
584
+ transform.yield
585
+ }
586
+ }
587
+
588
+ // -----
589
+
467
590
// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
468
591
func.func @diagnostic_unpack (%arg0: tensor <32 x64 xf32 >, %arg1: tensor <2 x4 x32 x8 xf32 >) -> tensor <32 x64 xf32 > {
469
592
// expected-note @below {{target payload op}}
0 commit comments