Skip to content

Commit f50c6aa

Browse files
committed
add more lit tests for scf.parallel
1 parent c6847ec commit f50c6aa

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed

mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,116 @@ module attributes {transform.with_named_sequence} {
100100

101101
// -----
102102

103+
// CHECK-LABEL: func @fuse_two_parallel_reverse
104+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
105+
func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
// Payload for the "reverse" sibling-fusion test: two scf.parallel loops over
// the same 2x2 iteration space (lower bound 0, upper bound 2, step 1).
//   * first loop:  sum[i, j] = B[i, j] + 1.0
//   * second loop: B[i, j]   = sum[i, j] * A[i, j]
// The FileCheck expectations below require a single fused scf.parallel whose
// body contains the second loop's operations (load sum, load A, mulf, store B)
// followed by the first loop's operations (load B, addf, store sum), with no
// additional scf.parallel in between.
106+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
107+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
108+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
109+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
110+
%c2 = arith.constant 2 : index
111+
%c0 = arith.constant 0 : index
112+
%c1 = arith.constant 1 : index
113+
%c1fp = arith.constant 1.0 : f32
114+
// CHECK: [[SUM:%.*]] = memref.alloc()
115+
%sum = memref.alloc() : memref<2x2xf32>
116+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
117+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
118+
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
119+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
120+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
121+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
122+
// CHECK-NOT: scf.parallel
123+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
124+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
125+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
126+
// CHECK: scf.reduce
127+
// CHECK: }
// First input loop: reads %B, adds the f32 constant 1.0, writes %sum.
128+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
129+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
130+
%sum_elem = arith.addf %B_elem, %c1fp : f32
131+
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
132+
scf.reduce
133+
}
// Second input loop: reads %sum and %A, multiplies them, writes %B.
134+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
135+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
136+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
137+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
138+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
139+
scf.reduce
140+
}
141+
// CHECK: memref.dealloc [[SUM]]
142+
memref.dealloc %sum : memref<2x2xf32>
143+
return
144+
}
145+
module attributes {transform.with_named_sequence} {
// Transform script for the preceding payload function: matches the
// scf.parallel ops under %arg0, splits the resulting handle into two handles,
// and fuses %parallel#1 into %parallel#0 (the reverse of the #0-into-#1
// direction used by other tests in this file).
// NOTE(review): presumably #0/#1 follow the order of the loops in the payload;
// confirm against the split_handle documentation.
146+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
147+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
148+
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
149+
%fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op
150+
transform.yield
151+
}
152+
}
153+
154+
// -----
155+
156+
// CHECK-LABEL: func @fuse_reductions_two
157+
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
158+
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
// Payload for fusing two reducing scf.parallel loops over the same 2x2
// iteration space (lower bound 0, upper bound 2, step 1):
//   * first loop:  reduces the elements of %A with arith.addf, init 1.0
//   * second loop: reduces the elements of %B with arith.mulf, init 2.0
// The FileCheck expectations below require one scf.parallel producing two
// results with both init values, a single two-operand scf.reduce whose first
// region uses addf and whose second region uses mulf, and a return of both
// results of the fused loop. The LHS/RHS/R pattern variables are deliberately
// re-bound for the second reduction region.
159+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
160+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
161+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
162+
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
163+
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
164+
// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
165+
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
166+
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
167+
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
168+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
169+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
170+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
171+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
172+
// CHECK: scf.reduce.return %[[R]] : f32
173+
// CHECK: }
174+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
175+
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
176+
// CHECK: scf.reduce.return %[[R]] : f32
177+
// CHECK: }
178+
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
179+
%c2 = arith.constant 2 : index
180+
%c0 = arith.constant 0 : index
181+
%c1 = arith.constant 1 : index
182+
%init1 = arith.constant 1.0 : f32
183+
%init2 = arith.constant 2.0 : f32
// First input loop: additive reduction over %A, seeded with %init1.
184+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
185+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
186+
scf.reduce(%A_elem : f32) {
187+
^bb0(%lhs: f32, %rhs: f32):
188+
%1 = arith.addf %lhs, %rhs : f32
189+
scf.reduce.return %1 : f32
190+
}
191+
}
// Second input loop: multiplicative reduction over %B, seeded with %init2.
192+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
193+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
194+
scf.reduce(%B_elem : f32) {
195+
^bb0(%lhs: f32, %rhs: f32):
196+
%1 = arith.mulf %lhs, %rhs : f32
197+
scf.reduce.return %1 : f32
198+
}
199+
}
200+
return %res1, %res2 : f32, f32
201+
}
202+
module attributes {transform.with_named_sequence} {
// Transform script for the preceding payload function: matches the
// scf.parallel ops under %arg0, splits the resulting handle into two handles,
// and fuses %parallel#0 into %parallel#1.
// NOTE(review): presumably #0/#1 follow the order of the loops in the payload;
// confirm against the split_handle documentation.
203+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
204+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
205+
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
206+
%fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
207+
transform.yield
208+
}
209+
}
210+
211+
// -----
212+
103213
// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
104214
func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
105215
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
@@ -382,3 +492,37 @@ module attributes {transform.with_named_sequence} {
382492
transform.yield
383493
}
384494
}
495+
496+
// -----
497+
498+
func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
// Negative test payload: the first scf.parallel is one-dimensional (only %i,
// with column index fixed to %c0) while the second is two-dimensional (%i, %j),
// so the two loops do not share an iteration space. The diagnostic annotation
// placed directly above the first loop states the error message that the
// fusion attempt is required to report on that loop.
499+
%c2 = arith.constant 2 : index
500+
%c0 = arith.constant 0 : index
501+
%c1 = arith.constant 1 : index
502+
%c1fp = arith.constant 1.0 : f32
503+
%sum = memref.alloc() : memref<2x2xf32>
504+
// expected-error @below {{target and source iteration spaces must be equal}}
505+
scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
506+
%B_elem = memref.load %B[%i, %c0] : memref<2x2xf32>
507+
%sum_elem = arith.addf %B_elem, %c1fp : f32
508+
memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32>
509+
scf.reduce
510+
}
// Second loop: full 2x2 space; reads %sum and %A, multiplies, writes %B.
511+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
512+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
513+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
514+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
515+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
516+
scf.reduce
517+
}
518+
memref.dealloc %sum : memref<2x2xf32>
519+
return
520+
}
521+
module attributes {transform.with_named_sequence} {
// Transform script for the preceding negative test: matches the scf.parallel
// ops under %arg0, splits the resulting handle into two handles, and attempts
// to fuse %parallel#0 into %parallel#1. The payload's diagnostic annotation
// indicates this fusion is meant to be rejected.
// NOTE(review): presumably #0/#1 follow the order of the loops in the payload;
// confirm against the split_handle documentation.
522+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
523+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
524+
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
525+
%fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
526+
transform.yield
527+
}
528+
}

0 commit comments

Comments
 (0)