@@ -100,6 +100,116 @@ module attributes {transform.with_named_sequence} {
100
100
101
101
// -----
102
102
103
// Test: fuse two sibling scf.parallel loops over the same 2x2 iteration space,
// with the second loop passed as the first operand of fuse_sibling and the
// first loop as the second operand. The expected output is a single
// scf.parallel whose body holds the second loop's operations (load %sum,
// load %A, mulf, store to %B) followed by the first loop's operations
// (load %B, addf, store to %sum).

// CHECK-LABEL: func @fuse_two_parallel_reverse
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
// CHECK: [[SUM:%.*]] = memref.alloc()
  %sum = memref.alloc() : memref<2x2xf32>
// A single fused loop is expected; the directive that forbids a second
// scf.parallel sits between the two original bodies.
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK-NOT: scf.parallel
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
  // First loop: sum[i, j] = B[i, j] + 1.0
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %c1fp : f32
    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  // Second loop: B[i, j] = sum[i, j] * A[i, j]
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
// CHECK: memref.dealloc [[SUM]]
  memref.dealloc %sum : memref<2x2xf32>
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match both scf.parallel ops in the payload and split the handle into
    // one handle per loop (#0 = first loop, #1 = second loop).
    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Reverse direction: the second loop is the first operand here.
    %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
153
+
154
// -----

// Test: fuse two sibling scf.parallel loops that each carry one reduction
// (an addf reduction over %A and a mulf reduction over %B). The expected
// output is one scf.parallel with two results, both init values, both loads,
// and a single scf.reduce holding the addf region followed by the mulf region.

// CHECK-LABEL: func @fuse_reductions_two
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// LHS, RHS and R are deliberately rebound below for the second region.
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
// CHECK: scf.reduce.return %[[R]] : f32
// CHECK: }
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %init1 = arith.constant 1.0 : f32
  %init2 = arith.constant 2.0 : f32
  // First loop: addf reduction over the elements of %A, seeded with %init1.
  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    scf.reduce(%A_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.addf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  // Second loop: mulf reduction over the elements of %B, seeded with %init2.
  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
    scf.reduce(%B_elem : f32) {
    ^bb0(%lhs: f32, %rhs: f32):
      %1 = arith.mulf %lhs, %rhs : f32
      scf.reduce.return %1 : f32
    }
  }
  return %res1, %res2 : f32, f32
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match both scf.parallel ops in the payload and split the handle into
    // one handle per loop (#0 = first loop, #1 = second loop).
    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----
212
+
103
213
// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
104
214
func.func @fuse_2nd_for_into_1st (%A: tensor <128 xf32 >, %B: tensor <128 xf32 >) -> (tensor <128 xf32 >, tensor <128 xf32 >) {
105
215
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
@@ -382,3 +492,37 @@ module attributes {transform.with_named_sequence} {
382
492
transform.yield
383
493
}
384
494
}
495
+
496
// -----

// Negative test: the first scf.parallel iterates over one dimension and the
// second over two, so the fuse_sibling request below is expected to be
// rejected with a diagnostic. The diagnostic annotation inside the function
// must stay on the line directly above the 1-D loop it refers to.

func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c1fp = arith.constant 1.0 : f32
  %sum = memref.alloc() : memref<2x2xf32>
  // expected-error @below {{target and source iteration spaces must be equal}}
  scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
    %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32>
    %sum_elem = arith.addf %B_elem, %c1fp : f32
    memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32>
    scf.reduce
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
    %product_elem = arith.mulf %sum_elem, %A_elem : f32
    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
    scf.reduce
  }
  memref.dealloc %sum : memref<2x2xf32>
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match both scf.parallel ops in the payload and split the handle into
    // one handle per loop (#0 = 1-D loop, #1 = 2-D loop).
    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
0 commit comments