Skip to content

Commit 7e9f0b5

Browse files
Better algo by Mahesh
1 parent 86b83cf commit 7e9f0b5

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,74 +1386,76 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
13861386
rewriter.mergeBlocks(oldLoopBody, newLoopBody,
13871387
newLoopBody->getArguments().take_front(oldNumArguments));
13881388

1389-
// 5.a. Clone consumer after the cloned
1390-
// tensor.insert_slice/parallel_insert_slice op.
1391-
rewriter.setInsertionPointAfter(candidateSliceOp);
1389+
// 5. Set insertion point before terminator op of the loop and create a new
1390+
// tensor.insert_slice. In the scf.for case this is a clone of the
1391+
// candidateSliceOp whereas in the scf.forall case this is created from the
1392+
// operands of tensor.parallel_insert_slice.
1393+
tensor::InsertSliceOp clonedInsertSliceOp;
1394+
if (auto sliceOp =
1395+
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1396+
auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1397+
rewriter.setInsertionPoint(newForallOp.getTerminator());
1398+
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1399+
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1400+
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1401+
} else {
1402+
auto newForOp = cast<scf::ForOp>(newLoopOp);
1403+
rewriter.setInsertionPoint(newForOp.getBody()->getTerminator());
1404+
clonedInsertSliceOp =
1405+
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1406+
}
1407+
1408+
// 6.a. Clone consumer op.
13921409
auto newForOpBlockArgsForConsumerDest =
13931410
newLoopBody->getArguments().drop_front(oldNumArguments);
13941411
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
13951412
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
13961413

1397-
// 5.b. Replace all uses of the loop result with the result of the cloned
1398-
// tensor.insert_slice/parallel_insert_slice.
1414+
// 6.b. Replace all uses of the loop result with the result of the cloned
1415+
// tensor.insert_slice.
13991416
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
14001417
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1401-
if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
1402-
operandToReplace.set(sliceOp.getResult());
1403-
} else if (auto sliceOp =
1404-
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1405-
operandToReplace.set(sliceOp.getSource());
1406-
}
1418+
operandToReplace.set(clonedInsertSliceOp.getResult());
14071419
});
14081420

1409-
// 6 - Perform tiling of the cloned consumer and replace the OpOperand that's
1410-
// already tiled.
1411-
if (isInsertSliceOp) {
1412-
rewriter.setInsertionPointAfter(clonedConsumerOp);
1413-
} else {
1414-
rewriter.setInsertionPoint(cast<scf::ForallOp>(newLoopOp).getTerminator());
1415-
}
1416-
auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
1421+
// 7 - Perform tiling of the cloned consumer and replace the operand at
1422+
// `operandNumber` with the source of the cloned tensor.insert_slice op.
1423+
auto ossSliceOp =
1424+
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
14171425
FailureOr<TilingResult> tileAndFuseResult =
14181426
tensor::replaceInsertSliceWithTiledConsumer(
14191427
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
14201428
if (failed(tileAndFuseResult)) {
14211429
return failure();
14221430
}
1423-
tileAndFuseResult->tiledOps[0]
1424-
->getOpOperand(operandNumber)
1425-
.set(candidateSliceOp->getOperand(0));
1431+
rewriter.replaceAllUsesWith(
1432+
tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
1433+
clonedInsertSliceOp.getSource());
14261434

1427-
// 7 - Extract offset/sizes/strides required to create the
1435+
// 8 - Extract offset/sizes/strides required to create the
14281436
// tensor.insert_slice/parallel_insert_slice for each result of the consumer.
14291437
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
14301438
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
14311439
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
14321440

1433-
// 8. Check all insert stride is 1.
1441+
// 9. Check all insert stride is 1.
14341442
if (llvm::any_of(strides, [](OpFoldResult stride) {
14351443
return !isConstantIntValue(stride, 1);
14361444
})) {
14371445
return rewriter.notifyMatchFailure(
14381446
candidateSliceOp, "containingOp's result yield with stride");
14391447
}
14401448

1441-
// 9. Try to get iter domain position from input position.
1449+
// 10. Try to get iter domain position from input position.
14421450
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1443-
1444-
if (isInsertSliceOp) {
1445-
rewriter.setInsertionPointAfter(clonedConsumerOp);
1446-
} else {
1447-
rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
1448-
}
14491451
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
14501452
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
14511453
iterDomainSizes))) {
14521454
return rewriter.notifyMatchFailure(
14531455
clonedConsumerOp, "can't get iter domain position from input position");
14541456
}
14551457

1456-
// 10. Try to fetch the offset and size for all results of the cloned
1458+
// 11. Try to fetch the offset and size for all results of the cloned
14571459
// consumer. This would then be used to form the corresponding
14581460
// tensor.insert_slice/parallel_insert_slice later.
14591461
unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();

0 commit comments

Comments
 (0)