Skip to content

Commit ba5591e

Browse files
[mlir][Transform] Reuse bbArgs in FuseIntoContainingOp (#135066)
When fusing two ops with the same output operand using FuseIntoContainingOp, the current implementation makes both ops write into different values pointing to the same tensor. This, in the end, will bufferize into two different buffers, which is sub-optimal. The current patch solves this problem, adding support to reuse the tensor by both consumer and producer. More precisely, before FuseIntoContainingOp is applied, we may have two ops that write into the same output tensor. However, the consumer would be tiled, thus the op would write into the loop iter_args (i.e., it does not write directly into the original tensor). When the producer is fused into the loop, the output tensor of the producer remains the same, so the consumer and producer write into two different values (the consumer writes into the iter_args and the producer into the original tensor). The current patch clones the producer into the loop and checks whether the producer is writing to the same value pointed to by the loop inits, in which case it makes the producer's output point to that tensor.
1 parent c4f7ab1 commit ba5591e

File tree

2 files changed

+186
-0
lines changed

2 files changed

+186
-0
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,54 @@ static Operation *replaceForAllWithNewSignature(
718718
return newforallOp;
719719
}
720720

721+
/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
722+
/// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
723+
/// outer loop. To determine the second condition, this function iterates
724+
/// using a worklist over the enclosing loops, trying to find 'src' in any of
725+
/// the parent loop's iter args.
726+
static bool sameOrEquivalentIterArg(Value src, Value dst) {
727+
// Stack like vector containing possible iterArgs candidates. The first one
728+
// is dst, and we will transverse the IR from there.
729+
SmallVector<Value> destWorklist;
730+
destWorklist.push_back(dst);
731+
732+
while (!destWorklist.empty()) {
733+
Value currentDst = destWorklist.pop_back_val();
734+
735+
// We have found the same operand in some iter arg in the loop structure,
736+
// so src and dst are equivalent.
737+
if (src == currentDst)
738+
return true;
739+
740+
// The operands are not equivalent, look for enclosing loops over
741+
// currentDst.
742+
auto bbArg = dyn_cast<BlockArgument>(currentDst);
743+
if (!bbArg)
744+
continue;
745+
746+
Block *parentBlock = bbArg.getOwner();
747+
assert(parentBlock && "unlinked block argument");
748+
749+
Operation *parentOp = parentBlock->getParentOp();
750+
assert(parentOp && "expected block argument with parent operation");
751+
752+
// Check if parent is loop-like. If it's not, do not add it to the worklist.
753+
auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
754+
if (!parentLoop)
755+
continue;
756+
757+
for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
758+
// No need to check for null as innerIterArg is tied to parentLoop.
759+
OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
760+
Value loopBlockArgument =
761+
parentLoop->getOperand(operand->getOperandNumber());
762+
destWorklist.push_back(loopBlockArgument);
763+
}
764+
}
765+
766+
return false;
767+
}
768+
721769
/// Find the first "extract" user of `producerOp` and tile it right before its
722770
/// use. The tiled op is fused under the `containingOp`.
723771
/// Return this fused op on success or nullptr if anything fails.
@@ -755,6 +803,40 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
755803
OpBuilder::InsertionGuard guard(rewriter);
756804
rewriter.setInsertionPoint(sliceOpToTile);
757805

806+
// Clone the producer inside the consumer and try to update the producer init
807+
// operands using the loop bbArgs if applicable. More precisely, if the bbArg
808+
// of the container loop points to a value that it is used by the consumer op,
809+
// then, instead of using such value on the consumer, use the value coming
810+
// from the bbArg instead. This allows to reuse the output tensor (instead of
811+
// creating a new one) of the container when both producer and container write
812+
// to the same output.
813+
if (LoopLikeOpInterface containerLoop =
814+
dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
815+
Operation *clone = rewriter.clone(*producerOp);
816+
rewriter.modifyOpInPlace(clone, [&]() {
817+
// Iterate over the outputs of the producer and over the loop bbArgs and
818+
// check if any bbArg points to the same value as the producer output. In
819+
// such case, make the producer output point to the bbArg directly.
820+
for (OpOperand &initOperandPtr :
821+
cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
822+
Value producerOperand =
823+
clone->getOperand(initOperandPtr.getOperandNumber());
824+
for (BlockArgument containerIterArg :
825+
containerLoop.getRegionIterArgs()) {
826+
OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
827+
Value consumerOperand =
828+
containerLoop->getOperand(bbArg->getOperandNumber());
829+
// The producer has the same init as the loop bbArg, use it.
830+
if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
831+
initOperandPtr.set(containerIterArg);
832+
}
833+
}
834+
}
835+
});
836+
837+
tileableProducer = dyn_cast<TilingInterface>(clone);
838+
}
839+
758840
// Tile the producer.
759841
int64_t resultNumber =
760842
cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
@@ -797,6 +879,10 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
797879
rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
798880
resultNumber, offsets, sizes);
799881

882+
// Cleanup clone.
883+
if (dyn_cast<LoopLikeOpInterface>(containingOp))
884+
rewriter.eraseOp(tileableProducer);
885+
800886
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
801887
}
802888

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,106 @@ module {
206206
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
207207
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
208208

209+
module {
210+
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout
211+
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
212+
// CHECK-SAME: %[[INOUT:[0-9a-z]+]]: tensor<?xf32>
213+
func.func @fuse_tileable_op_through_bbarg_inout(%arg0: index, %arg1: tensor<?xf32>) -> tensor<?xf32> {
214+
%cst = arith.constant 4.200000e+01 : f32
215+
%c0 = arith.constant 0 : index
216+
%0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
217+
%d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
218+
%1 = affine.apply #map0()[%d0, %arg0]
219+
220+
// CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[INOUT]]) -> (tensor<?xf32>) {
221+
%2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg1) -> (tensor<?xf32>) {
222+
%3 = affine.apply #map1(%arg3)[%arg0]
223+
%4 = affine.min #map2(%arg3)[%d0, %arg0]
224+
%5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
225+
226+
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
227+
// CHECK: %[[T1:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
228+
// CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
229+
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
230+
231+
// CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
232+
%7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
233+
scf.forall.in_parallel {
234+
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
235+
}
236+
}
237+
// CHECK: }
238+
func.return %2 : tensor<?xf32>
239+
}
240+
241+
module attributes {transform.with_named_sequence} {
242+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
243+
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
244+
%1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
245+
246+
// linalg.fill is tileable. The op is tiled and fused.
247+
transform.structured.fuse_into_containing_op %0 into %1
248+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
249+
transform.yield
250+
}
251+
}
252+
}
253+
254+
// -----
255+
256+
module {
257+
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
258+
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
259+
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x?xf32>
260+
func.func @fuse_tileable_op_through_bbarg_inout_nested(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
261+
%c2 = arith.constant 2 : index
262+
%c1 = arith.constant 1 : index
263+
%c0 = arith.constant 0 : index
264+
%0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
265+
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
266+
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
267+
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
268+
// CHECK: scf.for {{.*}} iter_args(%[[BBARG0:.*]] = %[[ARG1]]) -> (tensor<?x?x?xf32>) {
269+
// CHECK: scf.for {{.*}} iter_args(%[[BBARG1:.*]] = %[[BBARG0]]) -> (tensor<?x?x?xf32>) {
270+
// CHECK: scf.for {{.*}} iter_args(%[[BBARG2:.*]] = %[[BBARG1]]) -> (tensor<?x?x?xf32>) {
271+
%1 = scf.for %arg2 = %c0 to %dim step %c1 iter_args(%arg3 = %arg1) -> (tensor<?x?x?xf32>) {
272+
%2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args(%arg5 = %arg3) -> (tensor<?x?x?xf32>) {
273+
%3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%arg7 = %arg5) -> (tensor<?x?x?xf32>) {
274+
// CHECK: %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
275+
// CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
276+
// CHECK: %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
277+
// CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
278+
%extracted_slice = tensor.extract_slice %0[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
279+
%extracted_slice_2 = tensor.extract_slice %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
280+
%4 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
281+
%inserted_slice = tensor.insert_slice %4 into %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<?x?x?xf32>
282+
scf.yield %inserted_slice : tensor<?x?x?xf32>
283+
}
284+
scf.yield %3 : tensor<?x?x?xf32>
285+
}
286+
scf.yield %2 : tensor<?x?x?xf32>
287+
}
288+
return %1 : tensor<?x?x?xf32>
289+
}
290+
291+
module attributes {transform.with_named_sequence} {
292+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
293+
%0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op
294+
%1 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
295+
%2:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
296+
%3:3 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
297+
transform.structured.fuse_into_containing_op %2#0 into %3#2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
298+
transform.yield
299+
}
300+
}
301+
}
302+
303+
// -----
304+
305+
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
306+
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
307+
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
308+
209309
module {
210310
// CHECK-LABEL: func.func @fuse_tileable_multi_output_op
211311
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index

0 commit comments

Comments
 (0)