@@ -48,7 +48,7 @@ module {

     // linalg.fill is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -92,7 +92,7 @@ module {

     // tensor.empty is not tileable. The op is cloned and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -139,7 +139,7 @@ module {

     // linalg.fill is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -188,7 +188,7 @@ module {

     // linalg.fill is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -249,7 +249,7 @@ module {

     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -285,7 +285,7 @@ module {
     %2 = transform.merge_handles %0, %0 : !transform.any_op

     // It shouldn't be a problem to fuse this handle.
-    transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -351,7 +351,7 @@ module {

     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -417,7 +417,7 @@ module {

     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
   }
 }
@@ -482,6 +482,81 @@ module {

     // linalg.generic is tileable. The op is tiled and fused.
     transform.structured.fuse_into_containing_op %0 into %1
-      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0) -> (d0)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_using_new_handle
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    %0 = linalg.generic {
+      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
+    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %d = arith.addf %a, %b : f32
+      linalg.yield %d : f32
+    } -> tensor<?xf32>
+
+    %1 = linalg.generic {
+      indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
+    } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %d = arith.mulf %a, %b : f32
+      linalg.yield %d : f32
+    } -> tensor<?xf32>
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %2 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %4 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %5 = affine.min #map2(%i)[%d0, %idx]
+      %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+      // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
+      %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1
+    func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
+    %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
   }
 }