@@ -437,3 +437,74 @@ module attributes {transform.with_named_sequence} {
437
437
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
438
438
// CHECK: }
439
439
// CHECK: return %[[LOOP_RESULT1]]#1 :
440
+
441
+ // -----
442
+
443
+ // This test case checks fusion of consumer even if the producer has multiple uses.
444
+ // The multiple uses of the producer essentially means that besides the consumer
445
+ // op in concern, the only other uses of the producer are allowed in :-
446
+ // 1. scf.yield
447
+ // 2. tensor.parallel_insert_slice
448
+
449
+ module {
450
+ module {
451
+ func.func @fuse_consumer_for_multi_use_producer (%arg0: tensor <256 x512 xf32 >, %arg1: tensor <512 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> (tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) {
452
+ %c0 = arith.constant 0 : index
453
+ %c64 = arith.constant 64 : index
454
+ %c256 = arith.constant 256 : index
455
+ %cst = arith.constant 0.000000e+00 : f32
456
+ %0 = tensor.empty () : tensor <256 x256 xf32 >
457
+ %1 = linalg.fill ins (%cst : f32 ) outs (%0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
458
+ %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args (%arg4 = %1 , %arg5 = %arg2 ) -> (tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) {
459
+ %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args (%arg7 = %arg4 ) -> (tensor <256 x256 xf32 >) {
460
+ %extracted_slice = tensor.extract_slice %arg7 [%arg3 , %arg6 ] [64 , 64 ] [1 , 1 ] : tensor <256 x256 xf32 > to tensor <64 x64 xf32 >
461
+ %extracted_slice_0 = tensor.extract_slice %arg0 [%arg3 , 0 ] [64 , 512 ] [1 , 1 ] : tensor <256 x512 xf32 > to tensor <64 x512 xf32 >
462
+ %extracted_slice_1 = tensor.extract_slice %arg1 [0 , %arg6 ] [512 , 64 ] [1 , 1 ] : tensor <512 x256 xf32 > to tensor <512 x64 xf32 >
463
+ %5 = linalg.matmul ins (%extracted_slice_0 , %extracted_slice_1 : tensor <64 x512 xf32 >, tensor <512 x64 xf32 >) outs (%extracted_slice : tensor <64 x64 xf32 >) -> tensor <64 x64 xf32 >
464
+ %inserted_slice = tensor.insert_slice %5 into %arg7 [%arg3 , %arg6 ] [64 , 64 ] [1 , 1 ] : tensor <64 x64 xf32 > into tensor <256 x256 xf32 >
465
+ scf.yield %inserted_slice : tensor <256 x256 xf32 >
466
+ }
467
+ %4 = linalg.add ins (%3 , %arg5 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >) outs (%0 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
468
+ scf.yield %3 , %4 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >
469
+ }
470
+ return %2#0 , %2#1 : tensor <256 x256 xf32 >, tensor <256 x256 xf32 >
471
+ }
472
+ }
473
+ module attributes {transform.with_named_sequence } {
474
+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
475
+ %0 = transform.structured.match ops {[" tensor.insert_slice" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
476
+ %consumer , %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
477
+ transform.yield
478
+ }
479
+ }
480
+ }
481
+ // CHECK: func.func @fuse_consumer_for_multi_use_producer(
482
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
483
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
484
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
485
+ // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
486
+ // CHECK: %[[dest1:.*]] = linalg.fill
487
+ // CHECK-SAME: outs(%[[dest0]] :
488
+ // CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
489
+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
490
+ // CHECK-SAME: {
491
+ // CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
492
+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
493
+ // CHECK-SAME: {
494
+ // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
495
+ // CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
496
+ // CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
497
+ // CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
498
+ // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
499
+ // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
500
+ // CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
501
+ // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
502
+ // CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
503
+ // CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
504
+ // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
505
+ // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
506
+ // CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
507
+ // CHECK: }
508
+ // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
509
+ // CHECK: }
510
+ // CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
0 commit comments