Skip to content

Commit b8c974f

Browse files
[MLIR][TilingInterface] Extend consumer fusion for multi-use of producer shared by terminator ops (#110105)
-- This commit extends consumer fusion so that it can take place even if the producer has multiple uses. -- Multiple uses are supported only in a restricted form: besides the consumer op being fused, the only other uses of the producer that are allowed are in: 1. scf.yield 2. tensor.parallel_insert_slice Signed-off-by: Abhishek Varma <[email protected]>
1 parent 8e0daab commit b8c974f

File tree

2 files changed

+94
-15
lines changed

2 files changed

+94
-15
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,21 +1481,29 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
14811481
/// failure otherwise.
14821482
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
14831483
Block *containingOpBlock) {
1484-
// Step 1. Check that the value has exactly one use.
1485-
if (!llvm::hasSingleElement(val.getUses()))
1486-
return failure();
1487-
// Step 2. Get uses.
1488-
OpOperand &operand = (*val.getUses().begin());
1489-
Operation *consumerOp = operand.getOwner();
1490-
// TODO: We have to init result of consumer before scf.for, use
1491-
// DestinationStyleOpInterface to get result shape from init for now.
1492-
// Add support for other op such as op has InferTypeOpInterface.
1493-
if (!isa<TilingInterface>(consumerOp) ||
1494-
!isa<DestinationStyleOpInterface>(consumerOp))
1495-
return failure();
1496-
if (containingOpBlock != consumerOp->getBlock())
1497-
return failure();
1498-
return &operand;
1484+
// Check that the value has exactly one use which isn't a scf.yield or a
1485+
// tensor.parallel_insert_slice op.
1486+
OpOperand *operand = nullptr;
1487+
for (OpOperand &opOperand : val.getUses()) {
1488+
Operation *consumerOp = opOperand.getOwner();
1489+
if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
1490+
continue;
1491+
if (operand)
1492+
return failure();
1493+
// TODO: We have to init result of consumer before scf.for, use
1494+
// DestinationStyleOpInterface to get result shape from init for now.
1495+
// Add support for other op such as op has InferTypeOpInterface.
1496+
if (!isa<TilingInterface>(consumerOp) ||
1497+
!isa<DestinationStyleOpInterface>(consumerOp))
1498+
return failure();
1499+
if (containingOpBlock != consumerOp->getBlock())
1500+
return failure();
1501+
operand = &opOperand;
1502+
}
1503+
1504+
if (operand)
1505+
return operand;
1506+
return failure();
14991507
}
15001508

15011509
/// Find the perfectly nested loops outside of given loop(included) sorted from

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,74 @@ module attributes {transform.with_named_sequence} {
437437
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
438438
// CHECK: }
439439
// CHECK: return %[[LOOP_RESULT1]]#1 :
440+
441+
// -----
442+
443+
// This test case checks fusion of consumer even if the producer has multiple uses.
444+
// Multiple uses are supported only in a restricted form: besides the consumer
445+
// op being fused, the only other uses of the producer that are allowed are in:
446+
// 1. scf.yield
447+
// 2. tensor.parallel_insert_slice
448+
449+
module {
450+
module {
451+
func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
452+
%c0 = arith.constant 0 : index
453+
%c64 = arith.constant 64 : index
454+
%c256 = arith.constant 256 : index
455+
%cst = arith.constant 0.000000e+00 : f32
456+
%0 = tensor.empty() : tensor<256x256xf32>
457+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
458+
%2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
459+
%3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) {
460+
%extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
461+
%extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
462+
%extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
463+
%5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
464+
%inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
465+
scf.yield %inserted_slice : tensor<256x256xf32>
466+
}
467+
%4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
468+
scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32>
469+
}
470+
return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32>
471+
}
472+
}
473+
module attributes {transform.with_named_sequence} {
474+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
475+
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
476+
%consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
477+
transform.yield
478+
}
479+
}
480+
}
481+
// CHECK: func.func @fuse_consumer_for_multi_use_producer(
482+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
483+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
484+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
485+
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
486+
// CHECK: %[[dest1:.*]] = linalg.fill
487+
// CHECK-SAME: outs(%[[dest0]] :
488+
// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
489+
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
490+
// CHECK-SAME: {
491+
// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
492+
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
493+
// CHECK-SAME: {
494+
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
495+
// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
496+
// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
497+
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
498+
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
499+
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
500+
// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
501+
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
502+
// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
503+
// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
504+
// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
505+
// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
506+
// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
507+
// CHECK: }
508+
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
509+
// CHECK: }
510+
// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :

0 commit comments

Comments
 (0)