@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
24
24
25
25
// -----
26
26
27
+ func.func @fuse_ops_between (%A: f32 , %B: f32 ) -> f32 {
28
+ %c2 = arith.constant 2 : index
29
+ %c0 = arith.constant 0 : index
30
+ %c1 = arith.constant 1 : index
31
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
32
+ scf.reduce
33
+ }
34
+ %res = arith.addf %A , %B : f32
35
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
36
+ scf.reduce
37
+ }
38
+ return %res : f32
39
+ }
40
+ // CHECK-LABEL: func @fuse_ops_between
41
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
42
+ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
43
+ // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
44
+ // CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
45
+ // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
46
+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
47
+ // CHECK: scf.reduce
48
+ // CHECK: }
49
+ // CHECK-NOT: scf.parallel
50
+
51
+ // -----
52
+
27
53
func.func @fuse_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
28
54
%c2 = arith.constant 2 : index
29
55
%c0 = arith.constant 0 : index
@@ -754,3 +780,36 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
754
780
// CHECK-LABEL: func @reductions_use_res_inside
755
781
// CHECK: scf.parallel
756
782
// CHECK: scf.parallel
783
+
784
+ // -----
785
+
786
+ func.func @reductions_use_res_between (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
787
+ %c2 = arith.constant 2 : index
788
+ %c0 = arith.constant 0 : index
789
+ %c1 = arith.constant 1 : index
790
+ %init1 = arith.constant 1.0 : f32
791
+ %init2 = arith.constant 2.0 : f32
792
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
793
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
794
+ scf.reduce (%A_elem : f32 ) {
795
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
796
+ %1 = arith.addf %lhs , %rhs : f32
797
+ scf.reduce.return %1 : f32
798
+ }
799
+ }
800
+ %res3 = arith.addf %res1 , %init2 : f32
801
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
802
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
803
+ scf.reduce (%B_elem : f32 ) {
804
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
805
+ %1 = arith.mulf %lhs , %rhs : f32
806
+ scf.reduce.return %1 : f32
807
+ }
808
+ }
809
+ return %res1 , %res2 , %res3 : f32 , f32 , f32
810
+ }
811
+
812
+ // instruction in between the loops uses the first loop result
813
+ // CHECK-LABEL: func @reductions_use_res_between
814
+ // CHECK: scf.parallel
815
+ // CHECK: scf.parallel
0 commit comments