@@ -387,3 +387,125 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
387
387
// CHECK-LABEL: func @do_not_fuse_alias
388
388
// CHECK: scf.parallel
389
389
// CHECK: scf.parallel
390
+
391
+ // -----
392
+
393
+ func.func @fuse_reductions (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
394
+ %c2 = arith.constant 2 : index
395
+ %c0 = arith.constant 0 : index
396
+ %c1 = arith.constant 1 : index
397
+ %init1 = arith.constant 1.0 : f32
398
+ %init2 = arith.constant 2.0 : f32
399
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
400
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
401
+ scf.reduce (%A_elem ) : f32 {
402
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
403
+ %1 = arith.addf %lhs , %rhs : f32
404
+ scf.reduce.return %1 : f32
405
+ }
406
+ scf.yield
407
+ }
408
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
409
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
410
+ scf.reduce (%B_elem ) : f32 {
411
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
412
+ %1 = arith.mulf %lhs , %rhs : f32
413
+ scf.reduce.return %1 : f32
414
+ }
415
+ scf.yield
416
+ }
417
+ return %res1 , %res2 : f32 , f32
418
+ }
419
+
420
+ // CHECK-LABEL: func @fuse_reductions
421
+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>)
422
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
423
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
424
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
425
+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
426
+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
427
+ // CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
428
+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
429
+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
430
+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
431
+ // CHECK: scf.reduce(%[[VAL_A]]) : f32 {
432
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
433
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
434
+ // CHECK: scf.reduce.return %[[R]] : f32
435
+ // CHECK: }
436
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
437
+ // CHECK: scf.reduce(%[[VAL_B]]) : f32 {
438
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
439
+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
440
+ // CHECK: scf.reduce.return %[[R]] : f32
441
+ // CHECK: }
442
+ // CHECK: scf.yield
443
+ // CHECK: return %[[RES]]#0, %[[RES]]#1
444
+
445
+ // -----
446
+
447
+ func.func @reductions_use_res (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
448
+ %c2 = arith.constant 2 : index
449
+ %c0 = arith.constant 0 : index
450
+ %c1 = arith.constant 1 : index
451
+ %init1 = arith.constant 1.0 : f32
452
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
453
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
454
+ scf.reduce (%A_elem ) : f32 {
455
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
456
+ %1 = arith.addf %lhs , %rhs : f32
457
+ scf.reduce.return %1 : f32
458
+ }
459
+ scf.yield
460
+ }
461
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%res1 ) -> f32 {
462
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
463
+ scf.reduce (%B_elem ) : f32 {
464
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
465
+ %1 = arith.mulf %lhs , %rhs : f32
466
+ scf.reduce.return %1 : f32
467
+ }
468
+ scf.yield
469
+ }
470
+ return %res1 , %res2 : f32 , f32
471
+ }
472
+
473
+ // %res1 is used as second scf.parallel arg, cannot fuse
474
+ // CHECK-LABEL: func @reductions_use_res
475
+ // CHECK: scf.parallel
476
+ // CHECK: scf.parallel
477
+
478
+ // -----
479
+
480
+ func.func @reductions_use_res_inside (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
481
+ %c2 = arith.constant 2 : index
482
+ %c0 = arith.constant 0 : index
483
+ %c1 = arith.constant 1 : index
484
+ %init1 = arith.constant 1.0 : f32
485
+ %init2 = arith.constant 2.0 : f32
486
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
487
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
488
+ scf.reduce (%A_elem ) : f32 {
489
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
490
+ %1 = arith.addf %lhs , %rhs : f32
491
+ scf.reduce.return %1 : f32
492
+ }
493
+ scf.yield
494
+ }
495
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
496
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
497
+ %sum = arith.addf %B_elem , %res1 : f32
498
+ scf.reduce (%sum ) : f32 {
499
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
500
+ %1 = arith.mulf %lhs , %rhs : f32
501
+ scf.reduce.return %1 : f32
502
+ }
503
+ scf.yield
504
+ }
505
+ return %res1 , %res2 : f32 , f32
506
+ }
507
+
508
+ // %res1 is used inside second scf.parallel arg, cannot fuse
509
+ // CHECK-LABEL: func @reductions_use_res_inside
510
+ // CHECK: scf.parallel
511
+ // CHECK: scf.parallel
0 commit comments