@@ -634,3 +634,57 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]
+
+ // -----
+
+ #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
+ module {
+   func.func private @tile_one_consumer_using_tile_and_fuse(%arg0: tensor<16x128x48x96xf32>, %arg1: tensor<16x96x48x128xf32>) -> tensor<16x96x48x128xf32> {
+     %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<16x128x48x96xf32>) outs(%arg1 : tensor<16x96x48x128xf32>) {
+     ^bb0(%in: f32, %out: f32):
+       linalg.yield %in : f32
+     } -> tensor<16x96x48x128xf32>
+     return %0 : tensor<16x96x48x128xf32>
+   }
+ }
+ module attributes {transform.with_named_sequence} {
+   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+       : (!transform.any_op) -> !transform.any_op
+     %a, %loops:4 = transform.structured.fuse %generic {tile_sizes = [1, 16, 16, 16], tile_interchange = [0, 1, 2, 3], apply_cleanup = false}
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+     transform.yield
+   }
+ }
+
+ // CHECK: func.func private @tile_one_consumer_using_tile_and_fuse(%[[VAL_0:.*]]: tensor<16x128x48x96xf32>, %[[VAL_1:.*]]: tensor<16x96x48x128xf32>) -> tensor<16x96x48x128xf32> {
+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
+ // CHECK: %[[VAL_4:.*]] = arith.constant 128 : index
+ // CHECK: %[[VAL_5:.*]] = arith.constant 48 : index
+ // CHECK: %[[VAL_6:.*]] = arith.constant 96 : index
+ // CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_7]] iter_args(%[[VAL_10:.*]] = %[[VAL_1]]) -> (tensor<16x96x48x128xf32>) {
+ // CHECK: %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (tensor<16x96x48x128xf32>) {
+ // CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<16x96x48x128xf32>) {
+ // CHECK: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<16x96x48x128xf32>) {
+ // CHECK: %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_12]], %[[VAL_15]], %[[VAL_18]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<16x128x48x96xf32> to tensor<1x16x16x16xf32>
+ // CHECK: %[[VAL_21:.*]] = tensor.extract_slice %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_18]], %[[VAL_15]], %[[VAL_12]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<16x96x48x128xf32> to tensor<1x16x16x16xf32>
+ // CHECK: %[[VAL_22:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_20]] : tensor<1x16x16x16xf32>) outs(%[[VAL_21]] : tensor<1x16x16x16xf32>) {
+ // CHECK: ^bb0(%[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
+ // CHECK: linalg.yield %[[VAL_23]] : f32
+ // CHECK: } -> tensor<1x16x16x16xf32>
+ // CHECK: %[[VAL_25:.*]] = tensor.insert_slice %[[VAL_26:.*]] into %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_18]], %[[VAL_15]], %[[VAL_12]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<1x16x16x16xf32> into tensor<16x96x48x128xf32>
+ // CHECK: scf.yield %[[VAL_25]] : tensor<16x96x48x128xf32>
+ // CHECK: }
+ // CHECK: scf.yield %[[VAL_27:.*]] : tensor<16x96x48x128xf32>
+ // CHECK: }
+ // CHECK: scf.yield %[[VAL_28:.*]] : tensor<16x96x48x128xf32>
+ // CHECK: }
+ // CHECK: scf.yield %[[VAL_29:.*]] : tensor<16x96x48x128xf32>
+ // CHECK: }
+ // CHECK: return %[[VAL_30:.*]] : tensor<16x96x48x128xf32>
+ // CHECK: }
+ // CHECK: }
+