@@ -676,3 +676,65 @@ module attributes {transform.with_named_sequence} {
676
676
// CHECK: }
677
677
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
678
678
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
679
+
680
+ // -----
681
+
682
+ #map = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
683
+ #map1 = affine_map <(d0 , d1 , d2 ) -> (d2 )>
684
+ #map2 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
685
+ module {
686
+ func.func @fuse_with_tilable_consumer_with_projected_permutations (%arg0: tensor <256 x256 xf32 >, %arg1: tensor <256 x256 xf32 >, %arg2: tensor <24 xf32 >) -> tensor <256 x256 x24 xf32 > {
687
+ %c0 = arith.constant 0 : index
688
+ %c64 = arith.constant 64 : index
689
+ %c256 = arith.constant 256 : index
690
+ %0 = tensor.empty () : tensor <256 x256 xf32 >
691
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args (%arg4 = %0 ) -> (tensor <256 x256 xf32 >) {
692
+ %extracted_slice = tensor.extract_slice %arg4 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x256 xf32 >
693
+ %extracted_slice_0 = tensor.extract_slice %arg0 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x256 xf32 >
694
+ %extracted_slice_1 = tensor.extract_slice %arg1 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x256 xf32 >
695
+ %4 = linalg.add ins (%extracted_slice_0 , %extracted_slice_1 : tensor <64 x256 xf32 >, tensor <64 x256 xf32 >) outs (%extracted_slice : tensor <64 x256 xf32 >) -> tensor <64 x256 xf32 >
696
+ %inserted_slice = tensor.insert_slice %4 into %arg4 [%arg3 , 0 ] [64 , 256 ] [1 , 1 ] : tensor <64 x256 xf32 > into tensor <256 x256 xf32 >
697
+ scf.yield %inserted_slice : tensor <256 x256 xf32 >
698
+ }
699
+ %2 = tensor.empty () : tensor <256 x256 x24 xf32 >
700
+ %3 = linalg.generic {index ing_maps = [#map , #map1 , #map2 ], iterator_types = [" parallel" , " parallel" , " parallel" ]} ins (%1 , %arg2 : tensor <256 x256 xf32 >, tensor <24 xf32 >) outs (%2 : tensor <256 x256 x24 xf32 >) {
701
+ ^bb0 (%in: f32 , %in_0: f32 , %out: f32 ):
702
+ %4 = arith.addf %in , %in_0 : f32
703
+ linalg.yield %4 : f32
704
+ } -> tensor <256 x256 x24 xf32 >
705
+ return %3 : tensor <256 x256 x24 xf32 >
706
+ }
707
+ }
708
+
709
+ // CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
710
+ // CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
711
+ // CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
712
+ // CHECK: %[[VAL_5:.*]] = arith.constant 256 : index
713
+ // CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
714
+ // CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
715
+ // CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
716
+ // CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
717
+ // CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
718
+ // CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
719
+ // CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
720
+ // CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
721
+ // CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
722
+ // CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
723
+ // CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
724
+ // CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
725
+ // CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
726
+ // CHECK: linalg.yield %[[VAL_23]] : f32
727
+ // CHECK: } -> tensor<64x256x24xf32>
728
+ // CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
729
+ // CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
730
+ // CHECK: }
731
+
732
+ module attributes {transform.with_named_sequence } {
733
+ transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
734
+ %slice_op = transform.structured.match ops {[" tensor.insert_slice" ]} in %arg1
735
+ : (!transform.any_op ) -> !transform.any_op
736
+ %a , %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
737
+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
738
+ transform.yield
739
+ }
740
+ }
0 commit comments