@@ -206,6 +206,106 @@ module {
206
206
#map1 = affine_map <(d0 )[s0 ] -> (d0 * s0 )>
207
207
#map2 = affine_map <(d0 )[s0 , s1 ] -> (-(d0 * s1 ) + s0 , s1 )>
208
208
209
// Test case: the producer (linalg.fill) writes into %arg1, and the same %arg1
// is also the shared_outs init of the scf.forall that consumes the fill
// result. The transform sequence at the bottom fuses the fill into the forall.
// The expected output has the fused fill writing into a slice of the forall's
// shared_outs block argument.
+ module {
210
+ // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout
211
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
212
+ // CHECK-SAME: %[[INOUT:[0-9a-z]+]]: tensor<?xf32>
213
+ func.func @fuse_tileable_op_through_bbarg_inout (%arg0: index , %arg1: tensor <?xf32 >) -> tensor <?xf32 > {
214
+ %cst = arith.constant 4.200000e+01 : f32
215
+ %c0 = arith.constant 0 : index
216
// Producer to be fused; note that its outs operand is the function argument %arg1.
+ %0 = linalg.fill ins (%cst : f32 ) outs (%arg1 : tensor <?xf32 >) -> tensor <?xf32 >
217
+ %d0 = tensor.dim %arg1 , %c0 : tensor <?xf32 >
218
// Upper bound of the forall. NOTE(review): #map0 is not defined in this hunk;
// presumably it is the ceildiv map declared earlier in this section - confirm.
+ %1 = affine.apply #map0 ()[%d0 , %arg0 ]
219
+
220
// The forall is expected to keep the function argument as its shared_outs init.
+ // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[INOUT]]) -> (tensor<?xf32>) {
221
+ %2 = scf.forall (%arg3 ) in (%1 ) shared_outs (%o = %arg1 ) -> (tensor <?xf32 >) {
222
+ %3 = affine.apply #map1 (%arg3 )[%arg0 ]
223
+ %4 = affine.min #map2 (%arg3 )[%d0 , %arg0 ]
224
// Output slice for the consumer, taken from the shared_outs block argument %o.
+ %5 = tensor.extract_slice %o [%3 ] [%4 ] [1 ] : tensor <?xf32 > to tensor <?xf32 >
225
+
226
// Expected after fusion: two slices of the block argument, one of which (T1)
// becomes the outs operand of the fused fill.
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
227
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
228
+ // CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
229
// Slice of the fill result: this is the in-loop use of the producer.
+ %6 = tensor.extract_slice %0 [%3 ] [%4 ] [1 ] : tensor <?xf32 > to tensor <?xf32 >
230
+
231
// The consumer is expected to read the fused fill (T2) and write into T0.
+ // CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
232
+ %7 = linalg.elemwise_unary ins (%6 : tensor <?xf32 >) outs (%5 : tensor <?xf32 >) -> tensor <?xf32 >
233
+ scf.forall.in_parallel {
234
+ tensor.parallel_insert_slice %7 into %o [%3 ] [%4 ] [1 ] : tensor <?xf32 > into tensor <?xf32 >
235
+ }
236
+ }
237
+ // CHECK: }
238
+ func.return %2 : tensor <?xf32 >
239
+ }
240
+
241
// Transform script: match the linalg.fill and the scf.forall in the payload,
// then fuse the former into the latter.
+ module attributes {transform.with_named_sequence } {
242
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
243
+ %0 = transform.structured.match ops {[" linalg.fill" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
244
+ %1 = transform.structured.match ops {[" scf.forall" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
245
+
246
+ // linalg.fill is tileable. The op is tiled and fused.
247
+ transform.structured.fuse_into_containing_op %0 into %1
248
+ : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
249
+ transform.yield
250
+ }
251
+ }
252
+ }
253
+
254
+ // -----
255
+
256
// Test case: same situation as a producer sharing its outs tensor with a loop
// init, but with three nested scf.for loops. The producer (elemwise_unary abs)
// writes into %arg1, and %arg1 is threaded through the iter_args of every
// loop level. The transform sequence at the bottom fuses one elemwise_unary
// into one of the loops.
+ module {
257
+ // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
258
+ // CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
259
+ // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x?xf32>
260
+ func.func @fuse_tileable_op_through_bbarg_inout_nested (%arg0: tensor <?x?x?xf32 >, %arg1: tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 > {
261
+ %c2 = arith.constant 2 : index
262
+ %c1 = arith.constant 1 : index
263
+ %c0 = arith.constant 0 : index
264
// Producer: abs over %arg0, written into %arg1.
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn <abs >} ins (%arg0 : tensor <?x?x?xf32 >) outs (%arg1 : tensor <?x?x?xf32 >) -> tensor <?x?x?xf32 >
265
+ %dim = tensor.dim %arg1 , %c0 : tensor <?x?x?xf32 >
266
+ %dim_0 = tensor.dim %arg1 , %c1 : tensor <?x?x?xf32 >
267
+ %dim_1 = tensor.dim %arg1 , %c2 : tensor <?x?x?xf32 >
268
// Expected: each loop's iter_arg is initialised from the enclosing loop's
// iter_arg, with the outermost one initialised from the function argument.
+ // CHECK: scf.for {{.*}} iter_args(%[[BBARG0:.*]] = %[[ARG1]]) -> (tensor<?x?x?xf32>) {
269
+ // CHECK: scf.for {{.*}} iter_args(%[[BBARG1:.*]] = %[[BBARG0]]) -> (tensor<?x?x?xf32>) {
270
+ // CHECK: scf.for {{.*}} iter_args(%[[BBARG2:.*]] = %[[BBARG1]]) -> (tensor<?x?x?xf32>) {
271
+ %1 = scf.for %arg2 = %c0 to %dim step %c1 iter_args (%arg3 = %arg1 ) -> (tensor <?x?x?xf32 >) {
272
+ %2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args (%arg5 = %arg3 ) -> (tensor <?x?x?xf32 >) {
273
+ %3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args (%arg7 = %arg5 ) -> (tensor <?x?x?xf32 >) {
274
// Expected in the innermost body: the fused abs op operates on a 1x1x1 slice
// of the innermost iter_arg, followed by the original exp op on its own slice
// of the same iter_arg.
+ // CHECK: %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
275
+ // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
276
+ // CHECK: %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
277
+ // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
278
// Slice of the producer result: this is the in-loop use of the producer.
+ %extracted_slice = tensor.extract_slice %0 [%arg2 , %arg4 , %arg6 ] [1 , 1 , 1 ] [1 , 1 , 1 ] : tensor <?x?x?xf32 > to tensor <1 x1 x1 xf32 >
279
// Output slice for the consumer, taken from the innermost iter_arg.
+ %extracted_slice_2 = tensor.extract_slice %arg7 [%arg2 , %arg4 , %arg6 ] [1 , 1 , 1 ] [1 , 1 , 1 ] : tensor <?x?x?xf32 > to tensor <1 x1 x1 xf32 >
280
+ %4 = linalg.elemwise_unary {fun = #linalg.unary_fn <exp >} ins (%extracted_slice : tensor <1 x1 x1 xf32 >) outs (%extracted_slice_2 : tensor <1 x1 x1 xf32 >) -> tensor <1 x1 x1 xf32 >
281
+ %inserted_slice = tensor.insert_slice %4 into %arg7 [%arg2 , %arg4 , %arg6 ] [1 , 1 , 1 ] [1 , 1 , 1 ] : tensor <1 x1 x1 xf32 > into tensor <?x?x?xf32 >
282
+ scf.yield %inserted_slice : tensor <?x?x?xf32 >
283
+ }
284
+ scf.yield %3 : tensor <?x?x?xf32 >
285
+ }
286
+ scf.yield %2 : tensor <?x?x?xf32 >
287
+ }
288
+ return %1 : tensor <?x?x?xf32 >
289
+ }
290
+
291
// Transform script: match all elemwise_unary ops and all scf.for ops, split
// each handle into its individual payload ops, and fuse %2#0 into %3#2.
// NOTE(review): presumably %2#0 is the abs producer and %3#2 the innermost
// loop (match results in walk order) - the expected output above is
// consistent with that, but confirm against the match op's ordering rules.
+ module attributes {transform.with_named_sequence } {
292
+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
293
+ %0 = transform.structured.match ops {[" linalg.elemwise_unary" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
294
+ %1 = transform.structured.match ops {[" scf.for" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
295
+ %2:2 = transform.split_handle %0 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
296
+ %3:3 = transform.split_handle %1 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op )
297
+ transform.structured.fuse_into_containing_op %2#0 into %3#2 : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
298
+ transform.yield
299
+ }
300
+ }
301
+ }
302
+
303
+ // -----
304
+
305
// Affine maps re-declared for the test section that follows.
// #map0: ceiling division of symbol s0 by symbol s1.
+ #map0 = affine_map <()[s0 , s1 ] -> (s0 ceildiv s1 )>
306
// #map1: dimension d0 scaled by symbol s0.
+ #map1 = affine_map <(d0 )[s0 ] -> (d0 * s0 )>
307
// #map2: two results, (s0 - d0 * s1) and s1. The preceding section applies an
// identical map under affine.min to obtain a slice size.
+ #map2 = affine_map <(d0 )[s0 , s1 ] -> (-(d0 * s1 ) + s0 , s1 )>
308
+
209
309
module {
210
310
// CHECK-LABEL: func.func @fuse_tileable_multi_output_op
211
311
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
0 commit comments