@@ -1386,74 +1386,76 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1386
1386
rewriter.mergeBlocks (oldLoopBody, newLoopBody,
1387
1387
newLoopBody->getArguments ().take_front (oldNumArguments));
1388
1388
1389
- // 5.a. Clone consumer after the cloned
1390
- // tensor.insert_slice/parallel_insert_slice op.
1391
- rewriter.setInsertionPointAfter (candidateSliceOp);
1389
+ // 5. Set insertion point before terminator op of the loop and create a new
1390
+ // tensor.insert_slice. In the scf.for case this is a clone of the
1391
+ // candidateSliceOp whereas in the scf.forall case this is created from the
1392
+ // operands of tensor.parallel_insert_slice.
1393
+ tensor::InsertSliceOp clonedInsertSliceOp;
1394
+ if (auto sliceOp =
1395
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1396
+ auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1397
+ rewriter.setInsertionPoint (newForallOp.getTerminator ());
1398
+ clonedInsertSliceOp = rewriter.create <tensor::InsertSliceOp>(
1399
+ loc, sliceOp.getSource (), sliceOp.getDest (), sliceOp.getMixedOffsets (),
1400
+ sliceOp.getMixedSizes (), sliceOp.getMixedStrides ());
1401
+ } else {
1402
+ auto newForOp = cast<scf::ForOp>(newLoopOp);
1403
+ rewriter.setInsertionPoint (newForOp.getBody ()->getTerminator ());
1404
+ clonedInsertSliceOp =
1405
+ cast<tensor::InsertSliceOp>(rewriter.clone (*candidateSliceOp));
1406
+ }
1407
+
1408
+ // 6.a. Clone consumer op.
1392
1409
auto newForOpBlockArgsForConsumerDest =
1393
1410
newLoopBody->getArguments ().drop_front (oldNumArguments);
1394
1411
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs (
1395
1412
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1396
1413
1397
- // 5 .b. Replace all uses of the loop result with the result of the cloned
1398
- // tensor.insert_slice/parallel_insert_slice .
1414
+ // 6.b. Replace all uses of the loop result with the result of the cloned
1415
+ // tensor.insert_slice.
1399
1416
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand (operandNumber);
1400
1417
rewriter.modifyOpInPlace (clonedConsumerOp, [&]() {
1401
- if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
1402
- operandToReplace.set (sliceOp.getResult ());
1403
- } else if (auto sliceOp =
1404
- dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1405
- operandToReplace.set (sliceOp.getSource ());
1406
- }
1418
+ operandToReplace.set (clonedInsertSliceOp.getResult ());
1407
1419
});
1408
1420
1409
- // 6 - Perform tiling of the cloned consumer and replace the OpOperand that's
1410
- // already tiled.
1411
- if (isInsertSliceOp) {
1412
- rewriter.setInsertionPointAfter (clonedConsumerOp);
1413
- } else {
1414
- rewriter.setInsertionPoint (cast<scf::ForallOp>(newLoopOp).getTerminator ());
1415
- }
1416
- auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
1421
+ // 7 - Perform tiling of the cloned consumer and replace the operand at
1422
+ // `operandNumber` with the source of the cloned tensor.insert_slice op.
1423
+ auto ossSliceOp =
1424
+ cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation ());
1417
1425
FailureOr<TilingResult> tileAndFuseResult =
1418
1426
tensor::replaceInsertSliceWithTiledConsumer (
1419
1427
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand (operandNumber));
1420
1428
if (failed (tileAndFuseResult)) {
1421
1429
return failure ();
1422
1430
}
1423
- tileAndFuseResult-> tiledOps [ 0 ]
1424
- -> getOpOperand (operandNumber)
1425
- . set (candidateSliceOp-> getOperand ( 0 ));
1431
+ rewriter. replaceAllUsesWith (
1432
+ tileAndFuseResult-> tiledOps [ 0 ]-> getOperand (operandNumber),
1433
+ clonedInsertSliceOp. getSource ( ));
1426
1434
1427
- // 7 - Extract offset/sizes/strides required to create the
1435
+ // 8 - Extract offset/sizes/strides required to create the
1428
1436
// tensor.insert_slice/parallel_insert_slice for each result of the consumer.
1429
1437
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets ();
1430
1438
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes ();
1431
1439
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides ();
1432
1440
1433
- // 8 . Check all insert stride is 1.
1441
+ // 9. Check that all insert strides are 1.
1434
1442
if (llvm::any_of (strides, [](OpFoldResult stride) {
1435
1443
return !isConstantIntValue (stride, 1 );
1436
1444
})) {
1437
1445
return rewriter.notifyMatchFailure (
1438
1446
candidateSliceOp, " containingOp's result yield with stride" );
1439
1447
}
1440
1448
1441
- // 9 . Try to get iter domain position from input position.
1449
+ // 10. Try to get iter domain position from input position.
1442
1450
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1443
-
1444
- if (isInsertSliceOp) {
1445
- rewriter.setInsertionPointAfter (clonedConsumerOp);
1446
- } else {
1447
- rewriter.setInsertionPointAfter (tileAndFuseResult->tiledOps [0 ]);
1448
- }
1449
1451
if (failed (clonedConsumerOp.getIterationDomainTileFromOperandTile (
1450
1452
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1451
1453
iterDomainSizes))) {
1452
1454
return rewriter.notifyMatchFailure (
1453
1455
clonedConsumerOp, " can't get iter domain position from input position" );
1454
1456
}
1455
1457
1456
- // 10 . Try to fetch the offset and size for all results of the cloned
1458
+ // 11. Try to fetch the offset and size for all results of the cloned
1457
1459
// consumer. This would then be used to form the corresponding
1458
1460
// tensor.insert_slice/parallel_insert_slice later.
1459
1461
unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults ();
0 commit comments