Skip to content

Commit 8f7b4a4

Browse files
committed
more tests
1 parent 083707e commit 8f7b4a4

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
2424

2525
// -----
2626

27+
// Test input: two identical scf.parallel loops with empty bodies, separated
// by an arith.addf that reads only the function arguments %A and %B.
// The directives that follow this function expect a single scf.parallel to
// remain, located after the addf.
func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
28+
%c2 = arith.constant 2 : index
29+
%c0 = arith.constant 0 : index
30+
%c1 = arith.constant 1 : index
31+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
32+
scf.reduce
33+
}
34+
// Op placed between the two loops; it uses no value produced by either loop.
%res = arith.addf %A, %B : f32
35+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
36+
scf.reduce
37+
}
38+
return %res : f32
39+
}
40+
// CHECK-LABEL: func @fuse_ops_between
41+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
42+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
43+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
44+
// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
45+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
46+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
47+
// CHECK: scf.reduce
48+
// CHECK: }
49+
// CHECK-NOT: scf.parallel
50+
51+
// -----
52+
2753
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
2854
%c2 = arith.constant 2 : index
2955
%c0 = arith.constant 0 : index
@@ -754,3 +780,36 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
754780
// CHECK-LABEL: func @reductions_use_res_inside
755781
// CHECK: scf.parallel
756782
// CHECK: scf.parallel
783+
784+
// -----
785+
786+
// Test input: two scf.parallel reduction loops over the same 2x2 iteration
// space. The first reduces elements of %A with arith.addf (init %init1), the
// second reduces elements of %B with arith.mulf (init %init2). An arith.addf
// between the loops consumes the first loop's result %res1.
func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
787+
%c2 = arith.constant 2 : index
788+
%c0 = arith.constant 0 : index
789+
%c1 = arith.constant 1 : index
790+
%init1 = arith.constant 1.0 : f32
791+
%init2 = arith.constant 2.0 : f32
792+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
793+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
794+
scf.reduce(%A_elem : f32) {
795+
^bb0(%lhs: f32, %rhs: f32):
796+
%1 = arith.addf %lhs, %rhs : f32
797+
scf.reduce.return %1 : f32
798+
}
799+
}
800+
// Uses %res1, the result of the first loop, and sits before the second loop.
%res3 = arith.addf %res1, %init2 : f32
801+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
802+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
803+
scf.reduce(%B_elem : f32) {
804+
^bb0(%lhs: f32, %rhs: f32):
805+
%1 = arith.mulf %lhs, %rhs : f32
806+
scf.reduce.return %1 : f32
807+
}
808+
}
809+
return %res1, %res2, %res3 : f32, f32, f32
810+
}
811+
812+
// The instruction between the two loops uses the first loop's result, so both loops are expected to remain (no fusion).
813+
// CHECK-LABEL: func @reductions_use_res_between
814+
// CHECK: scf.parallel
815+
// CHECK: scf.parallel

0 commit comments

Comments
 (0)