@@ -1147,12 +1147,24 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1147
1147
// / failure otherwise.
1148
1148
static FailureOr<OpOperand *> getConsumerFromUses (Value val,
1149
1149
Block *containingOpBlock) {
1150
- // Step 1. Check that the value has exactly one use.
1151
- if (!llvm::hasSingleElement (val.getUses ()))
1150
+ // Step 1. Check that the value has exactly one use except for scf.yield.
1151
+ OpOperand *operand = nullptr ;
1152
+ for (auto &use : val.getUses ()) {
1153
+ Operation *user = use.getOwner ();
1154
+ if (isa<tensor::InsertSliceOp>(user) ||
1155
+ isa<tensor::ParallelInsertSliceOp>(user))
1156
+ continue ;
1157
+ else {
1158
+ if (operand)
1159
+ return failure ();
1160
+ else
1161
+ operand = &use;
1162
+ }
1163
+ }
1164
+ if (!operand)
1152
1165
return failure ();
1153
1166
// Step 2. Get uses.
1154
- OpOperand &operand = (*val.getUses ().begin ());
1155
- Operation *consumerOp = operand.getOwner ();
1167
+ Operation *consumerOp = operand->getOwner ();
1156
1168
// TODO: We have to init result of consumer before scf.for, use
1157
1169
// DestinationStyleOpInterface to get result shape from init for now.
1158
1170
// Add support for other op such as op has InferTypeOpInterface.
@@ -1161,7 +1173,22 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1161
1173
return failure ();
1162
1174
if (containingOpBlock != consumerOp->getBlock ())
1163
1175
return failure ();
1164
- return &operand;
1176
+ return operand;
1177
+ }
1178
+
1179
+ // / Return perfectly outer loops of given ForOp(included), sorted from
1180
+ // / outer to inner.
1181
+ static SmallVector<scf::ForOp> getPerfectlyOuterLoops (scf::ForOp loop) {
1182
+ SmallVector<scf::ForOp> outerLoops = {loop};
1183
+ auto forOp = loop->getParentOfType <scf::ForOp>();
1184
+ while (forOp) {
1185
+ Block &body = forOp.getRegion ().front ();
1186
+ if (body.begin () != std::prev (body.end (), 2 ))
1187
+ break ;
1188
+ outerLoops.push_back (forOp);
1189
+ forOp = forOp->getParentOfType <scf::ForOp>();
1190
+ }
1191
+ return {outerLoops.rbegin (), outerLoops.rend ()};
1165
1192
}
1166
1193
1167
1194
// / Fetch the untiled consumer of a scf.for's result which is yielded by a
@@ -1181,9 +1208,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1181
1208
auto forOp = dyn_cast<scf::ForOp>(containingOp);
1182
1209
if (!forOp)
1183
1210
return failure ();
1184
- Value resultingValue = forOp->getResult (resultNumber);
1211
+ scf::ForOp topLevelForOp = getPerfectlyOuterLoops (forOp).front ();
1212
+ Value resultingValue = topLevelForOp->getResult (resultNumber);
1185
1213
1186
- return getConsumerFromUses (resultingValue, containingOp ->getBlock ());
1214
+ return getConsumerFromUses (resultingValue, topLevelForOp ->getBlock ());
1187
1215
}
1188
1216
1189
1217
// / Fetch the first untiled consumer of a scf.forall's result which is yielded
@@ -1302,8 +1330,8 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
1302
1330
// / Implementation of fusing consumer of a single slice by computing the
1303
1331
// / slice of the consumer in-place for scf loop.
1304
1332
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1305
- mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1306
- Operation *candidateSliceOp) {
1333
+ mlir::scf::tileAndFuseConsumerOfSliceImpl (RewriterBase &rewriter,
1334
+ Operation *candidateSliceOp) {
1307
1335
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1308
1336
candidateSliceOp))
1309
1337
return failure ();
@@ -1337,22 +1365,25 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1337
1365
if (isInsertSliceOp) {
1338
1366
auto forOp = candidateSliceOp->getParentOfType <scf::ForOp>();
1339
1367
oldLoopOp = forOp;
1340
- llvm::append_range (newOuts, forOp.getInits ());
1341
- oldLoopBody = forOp.getBody ();
1342
1368
initSize = forOp.getInits ().size ();
1343
1369
} else {
1344
1370
auto forallOp = candidateSliceOp->getParentOfType <scf::ForallOp>();
1345
1371
oldLoopOp = forallOp;
1346
- llvm::append_range (newOuts, forallOp.getOutputs ());
1347
- oldLoopBody = forallOp.getBody ();
1348
1372
initSize = forallOp.getOutputs ().size ();
1349
1373
rank = forallOp.getRank ();
1350
1374
}
1351
1375
1352
- if (failed (checkAssumptionForLoop (oldLoopOp, consumerOp))) {
1376
+ Operation *oldTopLevelLoop = oldLoopOp;
1377
+ SmallVector<scf::ForOp> oldNestedForOps, newNestedForOps;
1378
+ if (isInsertSliceOp) {
1379
+ oldNestedForOps = getPerfectlyOuterLoops (cast<scf::ForOp>(oldLoopOp));
1380
+ oldTopLevelLoop = oldNestedForOps.front ();
1381
+ }
1382
+ if (failed (checkAssumptionForLoop (oldTopLevelLoop, consumerOp))) {
1353
1383
return rewriter.notifyMatchFailure (
1354
- oldLoopOp, " containing loop op should either yield just one value or "
1355
- " have the consumer op as its first user" );
1384
+ oldTopLevelLoop,
1385
+ " containing loop op should either yield just one value or "
1386
+ " have the consumer op as its first user" );
1356
1387
}
1357
1388
1358
1389
OpBuilder::InsertionGuard g (rewriter);
@@ -1361,28 +1392,59 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1361
1392
auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1362
1393
SmallVector<Value> dpsInits =
1363
1394
llvm::map_to_vector (dstOp.getDpsInits (), [](Value v) { return v; });
1364
- if (llvm::is_contained (dpsInits, oldLoopOp ->getResult (resultNumber))) {
1395
+ if (llvm::is_contained (dpsInits, oldTopLevelLoop ->getResult (resultNumber))) {
1365
1396
return rewriter.notifyMatchFailure (
1366
1397
consumerOp,
1367
1398
" consumer op taking the result of scf.for as init is not supported" );
1368
1399
}
1369
- newOuts. append ( dpsInits) ;
1400
+ SmallVector<Value> newInitAppend = dpsInits;
1370
1401
1371
1402
Location loc = oldLoopOp->getLoc ();
1372
1403
1373
1404
// 3. Create new scf loop op.
1374
1405
rewriter.setInsertionPoint (consumerOp);
1406
+
1407
+ // 3.a Create new outer scf loops if necessary
1408
+ bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size () > 1 ;
1409
+ if (isNestedForOps) {
1410
+ for (auto &forOp : MutableArrayRef (oldNestedForOps).drop_back ()) {
1411
+ SmallVector<Value> newInits;
1412
+ newInits = llvm::to_vector (forOp.getInits ());
1413
+ newInits.append (newInitAppend.begin (), newInitAppend.end ());
1414
+ auto newLoop = rewriter.create <scf::ForOp>(
1415
+ forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1416
+ forOp.getStep (), newInits);
1417
+ newInitAppend = llvm::map_to_vector (
1418
+ newLoop.getRegionIterArgs ().take_back (newInitAppend.size ()),
1419
+ [](BlockArgument bArg) -> Value { return bArg; });
1420
+ rewriter.mergeBlocks (
1421
+ forOp.getBody (), newLoop.getBody (),
1422
+ newLoop.getBody ()->getArguments ().take_front (initSize + 1 ));
1423
+ rewriter.replaceOp (
1424
+ forOp, newLoop->getResults ().take_front (forOp->getNumResults ()));
1425
+ newNestedForOps.push_back (newLoop);
1426
+ }
1427
+ rewriter.setInsertionPoint (oldNestedForOps.back ());
1428
+ }
1429
+
1430
+ // 3.b Create new inner most scf loop
1375
1431
Operation *newLoopOp = nullptr ;
1376
1432
Block *newLoopBody = nullptr ;
1377
1433
if (isInsertSliceOp) {
1378
1434
auto forOp = cast<scf::ForOp>(oldLoopOp);
1435
+ llvm::append_range (newOuts, forOp.getInits ());
1436
+ newOuts.append (newInitAppend);
1437
+ oldLoopBody = forOp.getBody ();
1379
1438
auto newForOp = rewriter.create <scf::ForOp>(loc, forOp.getLowerBound (),
1380
1439
forOp.getUpperBound (),
1381
1440
forOp.getStep (), newOuts);
1382
1441
newLoopOp = newForOp;
1383
1442
newLoopBody = newForOp.getBody ();
1384
1443
} else {
1385
1444
auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1445
+ llvm::append_range (newOuts, forallOp.getOutputs ());
1446
+ newOuts.append (newInitAppend);
1447
+ oldLoopBody = forallOp.getBody ();
1386
1448
auto newForallOp = rewriter.create <scf::ForallOp>(
1387
1449
loc, forallOp.getMixedLowerBound (), forallOp.getMixedUpperBound (),
1388
1450
forallOp.getMixedStep (), newOuts, forallOp.getMapping ());
@@ -1496,28 +1558,155 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1496
1558
newForallOp.getBody ()->getArguments ().drop_front (rank + initSize));
1497
1559
}
1498
1560
1499
- // 12. Replace the result of scf loop and consumer op with new loop's results.
1561
+ // 12. Restore outer loops from inner to outer
1562
+ if (isNestedForOps) {
1563
+ newNestedForOps.push_back (cast<scf::ForOp>(newLoopOp));
1564
+ for (auto [outerLoop, innerLoop] :
1565
+ llvm::zip_equal (MutableArrayRef (newNestedForOps).drop_back (),
1566
+ MutableArrayRef (newNestedForOps).drop_front ())) {
1567
+ auto forOp = cast<scf::ForOp>(outerLoop);
1568
+ auto outerLoopYield =
1569
+ cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
1570
+ SmallVector<Value> newYields =
1571
+ llvm::to_vector (outerLoopYield.getOperands ());
1572
+ ValueRange additionalYields =
1573
+ innerLoop->getResults ().take_back (newInitAppend.size ());
1574
+ newYields.append (additionalYields.begin (), additionalYields.end ());
1575
+ rewriter.setInsertionPoint (outerLoopYield);
1576
+ rewriter.replaceOpWithNewOp <scf::YieldOp>(outerLoopYield, newYields);
1577
+ }
1578
+ }
1579
+
1580
+ // 13. Replace the result of scf loop and consumer op with new loop's results.
1500
1581
for (auto &&[oldResult, newResult] :
1501
1582
llvm::zip_first (oldLoopOp->getResults (), newLoopOp->getResults ())) {
1502
1583
rewriter.replaceAllUsesWith (oldResult, newResult);
1503
1584
}
1504
1585
1586
+ Operation *newTopLevelLoop =
1587
+ isNestedForOps ? newNestedForOps.front () : newLoopOp;
1505
1588
for (auto &&[oldResult, newResult] :
1506
1589
llvm::zip (consumerOp->getResults (),
1507
- newLoopOp ->getResults ().drop_front (initSize))) {
1590
+ newTopLevelLoop ->getResults ().drop_front (initSize))) {
1508
1591
rewriter.replaceAllUsesWith (oldResult, newResult);
1509
1592
}
1510
1593
1511
- // 13 . Need to erase the old scf loop and the cloned consumer op.
1594
+ // 14 . Need to erase the old scf loop and the cloned consumer op.
1512
1595
rewriter.eraseOp (oldLoopOp);
1513
1596
rewriter.eraseOp (clonedConsumerOp);
1514
1597
1598
+ // 15. Need to erase the cloned insertSliceOp and unused extractSliceOp in
1599
+ // avoid of complex domination analysis
1600
+ assert (clonedInsertSliceOp->hasOneUse ());
1601
+ auto unUsedExtractOp =
1602
+ cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers ().begin ()));
1603
+ rewriter.eraseOp (unUsedExtractOp);
1604
+ rewriter.eraseOp (clonedInsertSliceOp);
1605
+
1515
1606
return scf::SCFFuseConsumerOfSliceResult{
1516
1607
consumerOpOperand,
1517
1608
&(tileAndFuseResult->tiledOps [0 ]->getOpOperand (operandNumber)),
1518
1609
tileAndFuseResult->tiledOps };
1519
1610
}
1520
1611
1612
+ // / Get the result of top-level loop which yields the target InsertSliceOp. E.g
1613
+ // / ```
1614
+ // / %1 = scf.for
1615
+ // / %2 = scf.for
1616
+ // / %3 = scf.for
1617
+ // / ...
1618
+ // / %4 = insert
1619
+ // / yield %4
1620
+ // / %5 = insert %3
1621
+ // / yield %5
1622
+ // / yield %2
1623
+ // / ```
1624
+ // / @param targetSliceOp: %4 = insert
1625
+ // / @param insertSliceOpChain: chain of all related insert sliceOp
1626
+ // / @return resultValue: %1
1627
+ static FailureOr<Value> getResultOfTopLevelLoopYieldInsertSliceOp (
1628
+ Operation *targetSliceOp,
1629
+ SmallVectorImpl<OffsetSizeAndStrideOpInterface> &insertSliceOpChain,
1630
+ int curDepth = 0 , int maxDepth = 5 ) {
1631
+ assert (isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1632
+ // Control recursive time in avoid of stack overflow
1633
+ if (curDepth > maxDepth)
1634
+ return failure ();
1635
+
1636
+ insertSliceOpChain.push_back (
1637
+ cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
1638
+ Value resultOfLoop;
1639
+ if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
1640
+ Value destValue = sliceOp.getDest ();
1641
+ auto iterArg = cast<BlockArgument>(destValue);
1642
+ auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner ()->getParentOp ());
1643
+ if (!forallOp)
1644
+ return failure ();
1645
+ resultOfLoop = forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg));
1646
+ } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
1647
+ Value resultValue = sliceOp.getResult ();
1648
+ for (auto &useOperand : resultValue.getUses ()) {
1649
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
1650
+ if (llvm::detail::isPresent (resultOfLoop))
1651
+ return failure ();
1652
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
1653
+ if (!forOp)
1654
+ return failure ();
1655
+ resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
1656
+ }
1657
+ }
1658
+ }
1659
+
1660
+ if (!llvm::detail::isPresent (resultOfLoop))
1661
+ return failure ();
1662
+
1663
+ while (true ) {
1664
+ bool walkThroughOuterLoop = false ;
1665
+ for (OpOperand &useOperand : resultOfLoop.getUses ()) {
1666
+ if (auto sliceOp =
1667
+ dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner ())) {
1668
+ return getResultOfTopLevelLoopYieldInsertSliceOp (
1669
+ sliceOp, insertSliceOpChain, curDepth + 1 );
1670
+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
1671
+ // walk through outer loop
1672
+ auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
1673
+ if (!forOp)
1674
+ return failure ();
1675
+ resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
1676
+ walkThroughOuterLoop = true ;
1677
+ break ;
1678
+ }
1679
+ }
1680
+ if (!walkThroughOuterLoop)
1681
+ break ;
1682
+ }
1683
+ return resultOfLoop;
1684
+ }
1685
+
1686
+ // / Fusing real consumer of a single slice even within complex nested loops via
1687
+ // / multiple application of `tileAndFuseConsumerOfSliceImpl`.
1688
+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
1689
+ mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1690
+ Operation *candidateSliceOp) {
1691
+ SmallVector<OffsetSizeAndStrideOpInterface> sliceOpChain;
1692
+ if (failed (getResultOfTopLevelLoopYieldInsertSliceOp (candidateSliceOp,
1693
+ sliceOpChain)))
1694
+ return failure ();
1695
+
1696
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
1697
+ // reverse from outer to inner
1698
+ std::reverse (sliceOpChain.begin (), sliceOpChain.end ());
1699
+ // multiple application of `tileAndFuseConsumerOfSliceImpl`
1700
+ for (auto &sliceOp : sliceOpChain) {
1701
+ fuseConsumerResult = tileAndFuseConsumerOfSliceImpl (rewriter, sliceOp);
1702
+ if (failed (fuseConsumerResult)) {
1703
+ return rewriter.notifyMatchFailure (sliceOp,
1704
+ " could not fuse consumer of sliceOp" );
1705
+ }
1706
+ }
1707
+ return fuseConsumerResult;
1708
+ }
1709
+
1521
1710
// ===----------------------------------------------------------------------===//
1522
1711
// lowerToLoopsUsingSCFForOp implementation.
1523
1712
// ===----------------------------------------------------------------------===//
0 commit comments