Skip to content

Commit 4fd26b6

Browse files
Address algo related comments
1 parent 045db97 commit 4fd26b6

File tree

2 files changed

+298
-112
lines changed

2 files changed

+298
-112
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 156 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/Arith/Utils/Utils.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1920
#include "mlir/Dialect/SCF/Utils/Utils.h"
2021
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2122
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -1173,20 +1174,8 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
11731174
unsigned resultNumber =
11741175
cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
11751176

1176-
// Check that the consumer results in exactly one value.
1177-
// TODO: Support fusion for consumers yielding more than one result.
1178-
if (consumerOp->getResults().size() != 1) {
1179-
return rewriter.notifyMatchFailure(
1180-
consumerOp,
1181-
"only those consumers returning exactly one result are supported");
1182-
}
11831177
Operation *containingOp = candidateSliceOp->getParentOp();
1184-
// Check containing op is "scf::ForOp".
11851178
auto forOp = static_cast<scf::ForOp>(containingOp);
1186-
// if (!forOp) {
1187-
// return rewriter.notifyMatchFailure(containingOp,
1188-
// "containing op is not a scf.for");
1189-
// }
11901179

11911180
OpBuilder::InsertionGuard g(rewriter);
11921181
rewriter.setInsertionPoint(candidateSliceOp);
@@ -1218,17 +1207,6 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
12181207
}
12191208

12201209
Location loc = forOp.getLoc();
1221-
SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
1222-
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
1223-
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
1224-
// Check all insert stride is 1.
1225-
if (llvm::any_of(strides, [](OpFoldResult stride) {
1226-
return !isConstantIntValue(stride, 1);
1227-
})) {
1228-
return rewriter.notifyMatchFailure(
1229-
candidateSliceOp, "containingOp's result yield with stride");
1230-
}
1231-
12321210
SmallVector<Value> newOuts(forOp.getInits());
12331211
newOuts.append(dpsInits);
12341212

@@ -1244,69 +1222,99 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
12441222
loopBody, newLoopBody,
12451223
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
12461224

1247-
// Clone the consumer after the insert_slice.
1225+
// 1 - Clone tensor.insert_slice after original tensor.insert_slice.
12481226
rewriter.setInsertionPointAfter(candidateSliceOp);
1249-
SmallVector<Value> newForOpBlockArgsForConsumerDest;
1250-
for (unsigned i = loopBody->getNumArguments(),
1251-
n = newLoopBody->getArguments().size();
1252-
i < n; i++) {
1253-
newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
1254-
}
1227+
SmallVector<Value> candidateSliceOpOperands =
1228+
llvm::to_vector(candidateSliceOp->getOperands());
1229+
tensor::InsertSliceOp clonedCandidateSliceOp =
1230+
mlir::clone(rewriter, candidateSliceOp,
1231+
candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1232+
1233+
// 2.a - Clone consumer after the cloned tensor.insert_slice op.
1234+
rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
1235+
SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
1236+
newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
1237+
[](BlockArgument b) -> Value { return b; });
12551238
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
12561239
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1240+
tileableConsumer = clonedConsumerOp;
12571241

1258-
// Replace scf.for result's use in the cloned consumer with insert_slice
1259-
// result.
1242+
// 2.b - Replace all uses of the loop result with the result of the cloned
1243+
// tensor.insert_slice.
12601244
rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
1261-
candidateSliceOp.getResult(),
1245+
clonedCandidateSliceOp.getResult(),
12621246
[&](OpOperand &operand) {
12631247
return operand.getOwner() == clonedConsumerOp;
12641248
});
12651249

1266-
// Generate the tiled implementation of the consumer of the source.
1267-
rewriter.setInsertionPoint(candidateSliceOp);
1250+
// 3 - Perform tiling of the cloned consumer.
1251+
rewriter.setInsertionPointAfter(clonedConsumerOp);
12681252
FailureOr<TilingResult> tileAndFuseResult =
12691253
tensor::replaceInsertSliceWithTiledConsumer(
12701254
rewriter,
1271-
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
1255+
cast<OffsetSizeAndStrideOpInterface>(
1256+
clonedCandidateSliceOp.getOperation()),
12721257
clonedConsumerOp->getOpOperand(operandNumber));
12731258
if (failed(tileAndFuseResult)) {
12741259
return rewriter.notifyMatchFailure(tileableConsumer,
12751260
"failed to tile consumer op: ");
12761261
}
12771262

1278-
// Update the source of the candidateSlice to be the cloned consumer.
1279-
SmallVector<Value> candidateSliceOpOperands =
1280-
llvm::to_vector(candidateSliceOp->getOperands());
1281-
candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
1282-
auto bbArgs = newforOp.getBody()->getArguments();
1283-
candidateSliceOpOperands[1] = bbArgs[1 + forOp.getInits().size() + 0];
1284-
tensor::InsertSliceOp clonedCandidateSliceOp =
1285-
mlir::clone(rewriter, candidateSliceOp,
1286-
candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1263+
// 4 - Extract offset/sizes/strides required to create the tensor.insert_slice
1264+
// for each result of the consumer.
1265+
SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
1266+
SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
1267+
SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
1268+
// Check all insert stride is 1.
1269+
if (llvm::any_of(strides, [](OpFoldResult stride) {
1270+
return !isConstantIntValue(stride, 1);
1271+
})) {
1272+
return rewriter.notifyMatchFailure(
1273+
clonedCandidateSliceOp, "containingOp's result yield with stride");
1274+
}
1275+
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1276+
// Try to get iter domain position from input position.
1277+
rewriter.setInsertionPointAfter(clonedConsumerOp);
1278+
if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
1279+
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1280+
iterDomainSizes))) {
1281+
return rewriter.notifyMatchFailure(
1282+
tileableConsumer, "can't get iter domain position from input position");
1283+
}
12871284

1288-
rewriter.replaceAllUsesWith(candidateSliceOp, candidateSliceOp.getSource());
1289-
rewriter.eraseOp(clonedConsumerOp);
1285+
// Try to get all containing op result's position from iter domain position.
1286+
llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
1287+
llvm::SmallVector<OpFoldResult>>>
1288+
resultPositions(clonedConsumerOp->getNumResults());
1289+
for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
1290+
if (failed(tileableConsumer.getResultTilePosition(
1291+
rewriter, idx, iterDomainOffsets, iterDomainSizes,
1292+
resultPositions[idx].first, resultPositions[idx].second))) {
1293+
return rewriter.notifyMatchFailure(
1294+
tileableConsumer,
1295+
"can't get result domain position from iter domain position");
1296+
}
1297+
}
12901298

1291-
// Fix terminator.
1299+
// 5 - Fix terminator.
12921300
scf::YieldOp oldTerminatorOp =
12931301
static_cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
1294-
// llvm::outs()<<"\n========= DB - 5 ===========\n"<<funcOp<<"\n";
1295-
1296-
SmallVector<Value> newYieldOperands;
1297-
for (Value val : oldTerminatorOp.getResults()) {
1298-
if (val == candidateSliceOp.getSource()) {
1299-
newYieldOperands.push_back(candidateSliceOp.getResult());
1300-
} else {
1301-
newYieldOperands.push_back(val);
1302-
}
1303-
}
1304-
newYieldOperands.push_back(clonedCandidateSliceOp.getResult());
1302+
SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
13051303
rewriter.setInsertionPointAfter(oldTerminatorOp);
1304+
auto bbArgs = newforOp.getBody()->getArguments();
1305+
for (auto [idx, v] :
1306+
llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
1307+
SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
1308+
rewriter.getIndexAttr(1));
1309+
newYieldOperands.push_back(rewriter.create<tensor::InsertSliceOp>(
1310+
clonedCandidateSliceOp->getLoc(), v,
1311+
bbArgs[1 + forOp.getInits().size() + idx], resultPositions[idx].first,
1312+
resultPositions[idx].second, strides));
1313+
}
13061314
rewriter.create<scf::YieldOp>(loc, newYieldOperands);
13071315
rewriter.eraseOp(oldTerminatorOp);
13081316

1309-
// Replace the result of for and consumer op.
1317+
// 6 - Replace the result of scf.for and consumer op.
13101318
for (auto result : llvm::enumerate(forOp.getResults())) {
13111319
rewriter.replaceAllUsesWith(result.value(),
13121320
newforOp->getResult(result.index()));
@@ -1318,9 +1326,12 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
13181326
newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
13191327
}
13201328

1321-
// Need to erase the old for.
1329+
rewriter.replaceOp(candidateSliceOp, clonedCandidateSliceOp);
1330+
1331+
// 7 - Need to erase the old scf.for.
13221332
rewriter.eraseOp(forOp);
13231333
rewriter.eraseOp(consumerOp);
1334+
rewriter.eraseOp(clonedConsumerOp);
13241335

13251336
return scf::SCFFuseConsumerOfSliceResult{
13261337
consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
@@ -1373,13 +1384,7 @@ tileAndFuseConsumerOfSliceSCFForall(
13731384
unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
13741385
unsigned resultNumber =
13751386
cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
1376-
// Check that the consumer results in exactly one value.
1377-
// TODO: Support fusion for consumers yielding more than one result.
1378-
if (consumerOp->getResults().size() != 1) {
1379-
return rewriter.notifyMatchFailure(
1380-
consumerOp,
1381-
"only those consumers returning exactly one result are supported");
1382-
}
1387+
13831388
OpBuilder::InsertionGuard g(rewriter);
13841389
// Using candidateSliceOp->getParentOp() because we have the following case :-
13851390
// scf.forall.in_parallel {
@@ -1415,18 +1420,6 @@ tileAndFuseConsumerOfSliceSCFForall(
14151420
"consumer op taking the result of scf.forall as init is not supported");
14161421
}
14171422

1418-
SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
1419-
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
1420-
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
1421-
1422-
// Check all insert stride is 1.
1423-
if (llvm::any_of(strides, [](OpFoldResult stride) {
1424-
return !isConstantIntValue(stride, 1);
1425-
})) {
1426-
return rewriter.notifyMatchFailure(
1427-
candidateSliceOp, "containingOp's result yield with stride");
1428-
}
1429-
14301423
Location loc = forallOp.getLoc();
14311424
// Create new scf.forall op.
14321425
SmallVector<Value> newOuts(forallOp.getOutputs());
@@ -1444,49 +1437,100 @@ tileAndFuseConsumerOfSliceSCFForall(
14441437
loopBody, newLoopBody,
14451438
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
14461439

1447-
// Clone the consumer after the parallel_insert_slice.
1440+
// 1 - Clone tensor.parallel_insert_slice after the original
1441+
// tensor.parallel_insert_slice.
14481442
rewriter.setInsertionPointAfter(candidateSliceOp);
1449-
SmallVector<Value> newForOpBlockArgsForConsumerDest;
1450-
for (unsigned i = loopBody->getNumArguments(),
1451-
n = newLoopBody->getArguments().size();
1452-
i < n; i++) {
1453-
newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
1454-
}
1443+
SmallVector<Value> candidateSliceOpOperands =
1444+
llvm::to_vector(candidateSliceOp->getOperands());
1445+
tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
1446+
mlir::clone(rewriter, candidateSliceOp,
1447+
candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1448+
LLVM_DEBUG(llvm::dbgs() << "Created a clone of the candidate slice op : "
1449+
<< clonedCandidateSliceOp << "\n");
1450+
1451+
// 2 - Clone the consumer after the clone tensor.parallel_insert_slice.
1452+
rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
1453+
SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
1454+
newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
1455+
[](BlockArgument b) -> Value { return b; });
14551456
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
14561457
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1458+
tileableConsumer = clonedConsumerOp;
14571459

1458-
// Replace scf.forall result's use in the consumer with parallel_insert_slice
1459-
// source.
1460-
rewriter.replaceAllUsesWith(forallOp.getResult(resultNumber),
1461-
candidateSliceOp.getSource());
1460+
// 2.b - Replace all uses of the scf.forall's result use in the consumer with
1461+
// the source of the cloned tensor.parallel_insert_slice.
1462+
rewriter.replaceUsesWithIf(forallOp.getResult(resultNumber),
1463+
clonedCandidateSliceOp.getSource(),
1464+
[&](OpOperand &operand) {
1465+
return operand.getOwner() == clonedConsumerOp;
1466+
});
14621467

1463-
// Generate the tiled implementation of the consumer of the source.
1464-
rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
1468+
// 3 - Perform tiling of the cloned consumer.
1469+
rewriter.setInsertionPoint(newforallOp.getTerminator());
14651470
FailureOr<TilingResult> tileAndFuseResult =
14661471
tensor::replaceInsertSliceWithTiledConsumer(
14671472
rewriter,
1468-
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation()),
1473+
cast<OffsetSizeAndStrideOpInterface>(
1474+
clonedCandidateSliceOp.getOperation()),
14691475
clonedConsumerOp->getOpOperand(operandNumber));
14701476
if (failed(tileAndFuseResult)) {
14711477
return rewriter.notifyMatchFailure(tileableConsumer,
14721478
"failed to tile consumer op: ");
14731479
}
14741480

1475-
// Update the source of the candidateSlice to be the cloned consumer.
1476-
rewriter.setInsertionPointAfter(candidateSliceOp);
1477-
SmallVector<Value> candidateSliceOpOperands =
1478-
llvm::to_vector(candidateSliceOp->getOperands());
1479-
candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
1480-
auto bbArgs = newforallOp.getBody()->getArguments();
1481-
candidateSliceOpOperands[1] =
1482-
bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + 0];
1483-
tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
1484-
mlir::clone(rewriter, candidateSliceOp,
1485-
candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1486-
LLVM_DEBUG(llvm::dbgs() << "Created a clone of the candidate slice op : "
1487-
<< clonedCandidateSliceOp << "\n");
1481+
// 4 - Extract offset/sizes/strides required to create the
1482+
// tensor.parallel_insert_slice for each result of the consumer.
1483+
SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
1484+
SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
1485+
SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
1486+
// Check all insert stride is 1.
1487+
if (llvm::any_of(strides, [](OpFoldResult stride) {
1488+
return !isConstantIntValue(stride, 1);
1489+
})) {
1490+
return rewriter.notifyMatchFailure(
1491+
clonedCandidateSliceOp, "containingOp's result yield with stride");
1492+
}
1493+
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1494+
// Try to get iter domain position from input position.
1495+
rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
1496+
;
1497+
if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
1498+
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1499+
iterDomainSizes))) {
1500+
return rewriter.notifyMatchFailure(
1501+
tileableConsumer, "can't get iter domain position from input position");
1502+
}
14881503

1489-
rewriter.eraseOp(clonedConsumerOp);
1504+
// Try to get all containing op result's position from iter domain position.
1505+
llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
1506+
llvm::SmallVector<OpFoldResult>>>
1507+
resultPositions(clonedConsumerOp->getNumResults());
1508+
for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
1509+
if (failed(tileableConsumer.getResultTilePosition(
1510+
rewriter, idx, iterDomainOffsets, iterDomainSizes,
1511+
resultPositions[idx].first, resultPositions[idx].second))) {
1512+
return rewriter.notifyMatchFailure(
1513+
tileableConsumer,
1514+
"can't get result domain position from iter domain position");
1515+
}
1516+
}
1517+
1518+
// 5 - Fix terminator.
1519+
scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
1520+
SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
1521+
newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
1522+
Operation *firstYieldOp = yieldingOps.front();
1523+
rewriter.setInsertionPoint(firstYieldOp);
1524+
auto bbArgs = newforallOp.getBody()->getArguments();
1525+
for (auto [idx, v] :
1526+
llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
1527+
SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
1528+
rewriter.getIndexAttr(1));
1529+
rewriter.create<tensor::ParallelInsertSliceOp>(
1530+
firstYieldOp->getLoc(), v,
1531+
bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
1532+
resultPositions[idx].first, resultPositions[idx].second, strides);
1533+
}
14901534

14911535
// Replace the result of scf.forall and consumer op.
14921536
for (auto result : llvm::enumerate(forallOp.getResults())) {
@@ -1501,9 +1545,12 @@ tileAndFuseConsumerOfSliceSCFForall(
15011545
consumerResult.index()));
15021546
}
15031547

1504-
// Need to erase the old scf.forall and consumer.
1548+
// Need to erase the old scf.forall, consumer, cloned consumer and
1549+
// candidateSliceOp.
15051550
rewriter.eraseOp(forallOp);
15061551
rewriter.eraseOp(consumerOp);
1552+
rewriter.eraseOp(clonedConsumerOp);
1553+
rewriter.eraseOp(candidateSliceOp);
15071554

15081555
return scf::SCFFuseConsumerOfSliceResult{
15091556
consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};

0 commit comments

Comments
 (0)