16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
17
#include " mlir/Dialect/Arith/Utils/Utils.h"
18
18
#include " mlir/Dialect/Func/IR/FuncOps.h"
19
+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
19
20
#include " mlir/Dialect/SCF/Utils/Utils.h"
20
21
#include " mlir/Dialect/Tensor/IR/Tensor.h"
21
22
#include " mlir/Dialect/Utils/IndexingUtils.h"
@@ -1173,20 +1174,8 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
1173
1174
unsigned resultNumber =
1174
1175
cast<OpResult>((*consumerOpOperand)->get ()).getResultNumber ();
1175
1176
1176
- // Check that the consumer results in exactly one value.
1177
- // TODO: Support fusion for consumers yielding more than one result.
1178
- if (consumerOp->getResults ().size () != 1 ) {
1179
- return rewriter.notifyMatchFailure (
1180
- consumerOp,
1181
- " only those consumers returning exactly one result are supported" );
1182
- }
1183
1177
Operation *containingOp = candidateSliceOp->getParentOp ();
1184
- // Check containing op is "scf::ForOp".
1185
1178
auto forOp = static_cast <scf::ForOp>(containingOp);
1186
- // if (!forOp) {
1187
- // return rewriter.notifyMatchFailure(containingOp,
1188
- // "containing op is not a scf.for");
1189
- // }
1190
1179
1191
1180
OpBuilder::InsertionGuard g (rewriter);
1192
1181
rewriter.setInsertionPoint (candidateSliceOp);
@@ -1218,17 +1207,6 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
1218
1207
}
1219
1208
1220
1209
Location loc = forOp.getLoc ();
1221
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets ();
1222
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes ();
1223
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides ();
1224
- // Check all insert stride is 1.
1225
- if (llvm::any_of (strides, [](OpFoldResult stride) {
1226
- return !isConstantIntValue (stride, 1 );
1227
- })) {
1228
- return rewriter.notifyMatchFailure (
1229
- candidateSliceOp, " containingOp's result yield with stride" );
1230
- }
1231
-
1232
1210
SmallVector<Value> newOuts (forOp.getInits ());
1233
1211
newOuts.append (dpsInits);
1234
1212
@@ -1244,69 +1222,99 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
1244
1222
loopBody, newLoopBody,
1245
1223
newLoopBody->getArguments ().take_front (loopBody->getNumArguments ()));
1246
1224
1247
- // Clone the consumer after the insert_slice.
1225
+ // 1 - Clone tensor.insert_slice after the original tensor.insert_slice.
1248
1226
rewriter.setInsertionPointAfter (candidateSliceOp);
1249
- SmallVector<Value> newForOpBlockArgsForConsumerDest;
1250
- for (unsigned i = loopBody->getNumArguments (),
1251
- n = newLoopBody->getArguments ().size ();
1252
- i < n; i++) {
1253
- newForOpBlockArgsForConsumerDest.push_back (newLoopBody->getArgument (i));
1254
- }
1227
+ SmallVector<Value> candidateSliceOpOperands =
1228
+ llvm::to_vector (candidateSliceOp->getOperands ());
1229
+ tensor::InsertSliceOp clonedCandidateSliceOp =
1230
+ mlir::clone (rewriter, candidateSliceOp,
1231
+ candidateSliceOp->getResultTypes (), candidateSliceOpOperands);
1232
+
1233
+ // 2.a - Clone consumer after the cloned tensor.insert_slice op.
1234
+ rewriter.setInsertionPointAfter (clonedCandidateSliceOp);
1235
+ SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector (
1236
+ newLoopBody->getArguments ().drop_front (loopBody->getNumArguments ()),
1237
+ [](BlockArgument b) -> Value { return b; });
1255
1238
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs (
1256
1239
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1240
+ tileableConsumer = clonedConsumerOp;
1257
1241
1258
- // Replace scf.for result's use in the cloned consumer with insert_slice
1259
- // result .
1242
+ // 2.b - Replace all uses of the loop result with the result of the cloned
1243
+ // tensor.insert_slice .
1260
1244
rewriter.replaceUsesWithIf (forOp.getResult (resultNumber),
1261
- candidateSliceOp .getResult (),
1245
+ clonedCandidateSliceOp .getResult (),
1262
1246
[&](OpOperand &operand) {
1263
1247
return operand.getOwner () == clonedConsumerOp;
1264
1248
});
1265
1249
1266
- // Generate the tiled implementation of the consumer of the source .
1267
- rewriter.setInsertionPoint (candidateSliceOp );
1250
+ // 3 - Perform tiling of the cloned consumer .
1251
+ rewriter.setInsertionPointAfter (clonedConsumerOp );
1268
1252
FailureOr<TilingResult> tileAndFuseResult =
1269
1253
tensor::replaceInsertSliceWithTiledConsumer (
1270
1254
rewriter,
1271
- cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation ()),
1255
+ cast<OffsetSizeAndStrideOpInterface>(
1256
+ clonedCandidateSliceOp.getOperation ()),
1272
1257
clonedConsumerOp->getOpOperand (operandNumber));
1273
1258
if (failed (tileAndFuseResult)) {
1274
1259
return rewriter.notifyMatchFailure (tileableConsumer,
1275
1260
" failed to tile consumer op: " );
1276
1261
}
1277
1262
1278
- // Update the source of the candidateSlice to be the cloned consumer.
1279
- SmallVector<Value> candidateSliceOpOperands =
1280
- llvm::to_vector (candidateSliceOp->getOperands ());
1281
- candidateSliceOpOperands[0 ] = tileAndFuseResult->tiledValues [0 ];
1282
- auto bbArgs = newforOp.getBody ()->getArguments ();
1283
- candidateSliceOpOperands[1 ] = bbArgs[1 + forOp.getInits ().size () + 0 ];
1284
- tensor::InsertSliceOp clonedCandidateSliceOp =
1285
- mlir::clone (rewriter, candidateSliceOp,
1286
- candidateSliceOp->getResultTypes (), candidateSliceOpOperands);
1263
+ // 4 - Extract offset/sizes/strides required to create the tensor.insert_slice
1264
+ // for each result of the consumer.
1265
+ SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets ();
1266
+ SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes ();
1267
+ SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides ();
1268
+ // Check all insert stride is 1.
1269
+ if (llvm::any_of (strides, [](OpFoldResult stride) {
1270
+ return !isConstantIntValue (stride, 1 );
1271
+ })) {
1272
+ return rewriter.notifyMatchFailure (
1273
+ clonedCandidateSliceOp, " containingOp's result yield with stride" );
1274
+ }
1275
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1276
+ // Try to get iter domain position from input position.
1277
+ rewriter.setInsertionPointAfter (clonedConsumerOp);
1278
+ if (failed (tileableConsumer.getIterDomainTilePositionFromOperandPosition (
1279
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1280
+ iterDomainSizes))) {
1281
+ return rewriter.notifyMatchFailure (
1282
+ tileableConsumer, " can't get iter domain position from input position" );
1283
+ }
1287
1284
1288
- rewriter.replaceAllUsesWith (candidateSliceOp, candidateSliceOp.getSource ());
1289
- rewriter.eraseOp (clonedConsumerOp);
1285
+ // Try to get all containing op result's position from iter domain position.
1286
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
1287
+ llvm::SmallVector<OpFoldResult>>>
1288
+ resultPositions (clonedConsumerOp->getNumResults ());
1289
+ for (auto [idx, v] : llvm::enumerate (clonedConsumerOp->getResults ())) {
1290
+ if (failed (tileableConsumer.getResultTilePosition (
1291
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
1292
+ resultPositions[idx].first , resultPositions[idx].second ))) {
1293
+ return rewriter.notifyMatchFailure (
1294
+ tileableConsumer,
1295
+ " can't get result domain position from iter domain position" );
1296
+ }
1297
+ }
1290
1298
1291
- // Fix terminator.
1299
+ // 5 - Fix terminator.
1292
1300
scf::YieldOp oldTerminatorOp =
1293
1301
static_cast <scf::YieldOp>(newforOp.getBody ()->getTerminator ());
1294
- // llvm::outs()<<"\n========= DB - 5 ===========\n"<<funcOp<<"\n";
1295
-
1296
- SmallVector<Value> newYieldOperands;
1297
- for (Value val : oldTerminatorOp.getResults ()) {
1298
- if (val == candidateSliceOp.getSource ()) {
1299
- newYieldOperands.push_back (candidateSliceOp.getResult ());
1300
- } else {
1301
- newYieldOperands.push_back (val);
1302
- }
1303
- }
1304
- newYieldOperands.push_back (clonedCandidateSliceOp.getResult ());
1302
+ SmallVector<Value> newYieldOperands (oldTerminatorOp.getResults ());
1305
1303
rewriter.setInsertionPointAfter (oldTerminatorOp);
1304
+ auto bbArgs = newforOp.getBody ()->getArguments ();
1305
+ for (auto [idx, v] :
1306
+ llvm::enumerate (tileAndFuseResult->tiledOps [0 ]->getResults ())) {
1307
+ SmallVector<OpFoldResult> strides (resultPositions[idx].first .size (),
1308
+ rewriter.getIndexAttr (1 ));
1309
+ newYieldOperands.push_back (rewriter.create <tensor::InsertSliceOp>(
1310
+ clonedCandidateSliceOp->getLoc (), v,
1311
+ bbArgs[1 + forOp.getInits ().size () + idx], resultPositions[idx].first ,
1312
+ resultPositions[idx].second , strides));
1313
+ }
1306
1314
rewriter.create <scf::YieldOp>(loc, newYieldOperands);
1307
1315
rewriter.eraseOp (oldTerminatorOp);
1308
1316
1309
- // Replace the result of for and consumer op.
1317
+ // 6 - Replace the result of scf.for and consumer op.
1310
1318
for (auto result : llvm::enumerate (forOp.getResults ())) {
1311
1319
rewriter.replaceAllUsesWith (result.value (),
1312
1320
newforOp->getResult (result.index ()));
@@ -1318,9 +1326,12 @@ tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
1318
1326
newforOp->getResult (forOp.getInits ().size () + consumerResult.index ()));
1319
1327
}
1320
1328
1321
- // Need to erase the old for.
1329
+ rewriter.replaceOp (candidateSliceOp, clonedCandidateSliceOp);
1330
+
1331
+ // 7 - Need to erase the old scf.for.
1322
1332
rewriter.eraseOp (forOp);
1323
1333
rewriter.eraseOp (consumerOp);
1334
+ rewriter.eraseOp (clonedConsumerOp);
1324
1335
1325
1336
return scf::SCFFuseConsumerOfSliceResult{
1326
1337
consumerOp, tileAndFuseResult->tiledOps [0 ]->getResult (0 ), {}};
@@ -1373,13 +1384,7 @@ tileAndFuseConsumerOfSliceSCFForall(
1373
1384
unsigned operandNumber = (*consumerOpOperand)->getOperandNumber ();
1374
1385
unsigned resultNumber =
1375
1386
cast<OpResult>((*consumerOpOperand)->get ()).getResultNumber ();
1376
- // Check that the consumer results in exactly one value.
1377
- // TODO: Support fusion for consumers yielding more than one result.
1378
- if (consumerOp->getResults ().size () != 1 ) {
1379
- return rewriter.notifyMatchFailure (
1380
- consumerOp,
1381
- " only those consumers returning exactly one result are supported" );
1382
- }
1387
+
1383
1388
OpBuilder::InsertionGuard g (rewriter);
1384
1389
// Using candidateSliceOp->getParentOp() because we have the following case :-
1385
1390
// scf.forall.in_parallel {
@@ -1415,18 +1420,6 @@ tileAndFuseConsumerOfSliceSCFForall(
1415
1420
" consumer op taking the result of scf.forall as init is not supported" );
1416
1421
}
1417
1422
1418
- SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets ();
1419
- SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes ();
1420
- SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides ();
1421
-
1422
- // Check all insert stride is 1.
1423
- if (llvm::any_of (strides, [](OpFoldResult stride) {
1424
- return !isConstantIntValue (stride, 1 );
1425
- })) {
1426
- return rewriter.notifyMatchFailure (
1427
- candidateSliceOp, " containingOp's result yield with stride" );
1428
- }
1429
-
1430
1423
Location loc = forallOp.getLoc ();
1431
1424
// Create new scf.forall op.
1432
1425
SmallVector<Value> newOuts (forallOp.getOutputs ());
@@ -1444,49 +1437,100 @@ tileAndFuseConsumerOfSliceSCFForall(
1444
1437
loopBody, newLoopBody,
1445
1438
newLoopBody->getArguments ().take_front (loopBody->getNumArguments ()));
1446
1439
1447
- // Clone the consumer after the parallel_insert_slice.
1440
+ // 1 - Clone tensor.parallel_insert_slice after the original
1441
+ // tensor.parallel_insert_slice.
1448
1442
rewriter.setInsertionPointAfter (candidateSliceOp);
1449
- SmallVector<Value> newForOpBlockArgsForConsumerDest;
1450
- for (unsigned i = loopBody->getNumArguments (),
1451
- n = newLoopBody->getArguments ().size ();
1452
- i < n; i++) {
1453
- newForOpBlockArgsForConsumerDest.push_back (newLoopBody->getArgument (i));
1454
- }
1443
+ SmallVector<Value> candidateSliceOpOperands =
1444
+ llvm::to_vector (candidateSliceOp->getOperands ());
1445
+ tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
1446
+ mlir::clone (rewriter, candidateSliceOp,
1447
+ candidateSliceOp->getResultTypes (), candidateSliceOpOperands);
1448
+ LLVM_DEBUG (llvm::dbgs () << " Created a clone of the candidate slice op : "
1449
+ << clonedCandidateSliceOp << " \n " );
1450
+
1451
+ // 2 - Clone the consumer after the cloned tensor.parallel_insert_slice.
1452
+ rewriter.setInsertionPointAfter (clonedCandidateSliceOp);
1453
+ SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector (
1454
+ newLoopBody->getArguments ().drop_front (loopBody->getNumArguments ()),
1455
+ [](BlockArgument b) -> Value { return b; });
1455
1456
auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs (
1456
1457
rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1458
+ tileableConsumer = clonedConsumerOp;
1457
1459
1458
- // Replace scf.forall result's use in the consumer with parallel_insert_slice
1459
- // source.
1460
- rewriter.replaceAllUsesWith (forallOp.getResult (resultNumber),
1461
- candidateSliceOp.getSource ());
1460
+ // 2.b - Replace all uses of the scf.forall's result use in the consumer with
1461
+ // the source of the cloned tensor.parallel_insert_slice.
1462
+ rewriter.replaceUsesWithIf (forallOp.getResult (resultNumber),
1463
+ clonedCandidateSliceOp.getSource (),
1464
+ [&](OpOperand &operand) {
1465
+ return operand.getOwner () == clonedConsumerOp;
1466
+ });
1462
1467
1463
- // Generate the tiled implementation of the consumer of the source .
1464
- rewriter.setInsertionPoint (candidateSliceOp-> getParentOp ());
1468
+ // 3 - Perform tiling of the cloned consumer .
1469
+ rewriter.setInsertionPoint (newforallOp. getTerminator ());
1465
1470
FailureOr<TilingResult> tileAndFuseResult =
1466
1471
tensor::replaceInsertSliceWithTiledConsumer (
1467
1472
rewriter,
1468
- cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp.getOperation ()),
1473
+ cast<OffsetSizeAndStrideOpInterface>(
1474
+ clonedCandidateSliceOp.getOperation ()),
1469
1475
clonedConsumerOp->getOpOperand (operandNumber));
1470
1476
if (failed (tileAndFuseResult)) {
1471
1477
return rewriter.notifyMatchFailure (tileableConsumer,
1472
1478
" failed to tile consumer op: " );
1473
1479
}
1474
1480
1475
- // Update the source of the candidateSlice to be the cloned consumer.
1476
- rewriter.setInsertionPointAfter (candidateSliceOp);
1477
- SmallVector<Value> candidateSliceOpOperands =
1478
- llvm::to_vector (candidateSliceOp->getOperands ());
1479
- candidateSliceOpOperands[0 ] = tileAndFuseResult->tiledValues [0 ];
1480
- auto bbArgs = newforallOp.getBody ()->getArguments ();
1481
- candidateSliceOpOperands[1 ] =
1482
- bbArgs[forallOp.getRank () + forallOp.getOutputs ().size () + 0 ];
1483
- tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
1484
- mlir::clone (rewriter, candidateSliceOp,
1485
- candidateSliceOp->getResultTypes (), candidateSliceOpOperands);
1486
- LLVM_DEBUG (llvm::dbgs () << " Created a clone of the candidate slice op : "
1487
- << clonedCandidateSliceOp << " \n " );
1481
+ // 4 - Extract offset/sizes/strides required to create the
1482
+ // tensor.parallel_insert_slice for each result of the consumer.
1483
+ SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets ();
1484
+ SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes ();
1485
+ SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides ();
1486
+ // Check all insert stride is 1.
1487
+ if (llvm::any_of (strides, [](OpFoldResult stride) {
1488
+ return !isConstantIntValue (stride, 1 );
1489
+ })) {
1490
+ return rewriter.notifyMatchFailure (
1491
+ clonedCandidateSliceOp, " containingOp's result yield with stride" );
1492
+ }
1493
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1494
+ // Try to get iter domain position from input position.
1495
+ rewriter.setInsertionPointAfter (tileAndFuseResult->tiledOps [0 ]);
1496
+ ;
1497
+ if (failed (tileableConsumer.getIterDomainTilePositionFromOperandPosition (
1498
+ rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1499
+ iterDomainSizes))) {
1500
+ return rewriter.notifyMatchFailure (
1501
+ tileableConsumer, " can't get iter domain position from input position" );
1502
+ }
1488
1503
1489
- rewriter.eraseOp (clonedConsumerOp);
1504
+ // Try to get all containing op result's position from iter domain position.
1505
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
1506
+ llvm::SmallVector<OpFoldResult>>>
1507
+ resultPositions (clonedConsumerOp->getNumResults ());
1508
+ for (auto [idx, v] : llvm::enumerate (clonedConsumerOp->getResults ())) {
1509
+ if (failed (tileableConsumer.getResultTilePosition (
1510
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
1511
+ resultPositions[idx].first , resultPositions[idx].second ))) {
1512
+ return rewriter.notifyMatchFailure (
1513
+ tileableConsumer,
1514
+ " can't get result domain position from iter domain position" );
1515
+ }
1516
+ }
1517
+
1518
+ // 5 - Fix terminator.
1519
+ scf::InParallelOp newTerminatorOp = newforallOp.getTerminator ();
1520
+ SmallVector<Operation *> yieldingOps = llvm::map_to_vector (
1521
+ newTerminatorOp.getYieldingOps (), [](Operation &op) { return &op; });
1522
+ Operation *firstYieldOp = yieldingOps.front ();
1523
+ rewriter.setInsertionPoint (firstYieldOp);
1524
+ auto bbArgs = newforallOp.getBody ()->getArguments ();
1525
+ for (auto [idx, v] :
1526
+ llvm::enumerate (tileAndFuseResult->tiledOps [0 ]->getResults ())) {
1527
+ SmallVector<OpFoldResult> strides (resultPositions[idx].first .size (),
1528
+ rewriter.getIndexAttr (1 ));
1529
+ rewriter.create <tensor::ParallelInsertSliceOp>(
1530
+ firstYieldOp->getLoc (), v,
1531
+ bbArgs[forallOp.getRank () + forallOp.getOutputs ().size () + idx],
1532
+ resultPositions[idx].first , resultPositions[idx].second , strides);
1533
+ }
1490
1534
1491
1535
// Replace the result of scf.forall and consumer op.
1492
1536
for (auto result : llvm::enumerate (forallOp.getResults ())) {
@@ -1501,9 +1545,12 @@ tileAndFuseConsumerOfSliceSCFForall(
1501
1545
consumerResult.index ()));
1502
1546
}
1503
1547
1504
- // Need to erase the old scf.forall and consumer.
1548
+ // Need to erase the old scf.forall, consumer, cloned consumer and
1549
+ // candidateSliceOp.
1505
1550
rewriter.eraseOp (forallOp);
1506
1551
rewriter.eraseOp (consumerOp);
1552
+ rewriter.eraseOp (clonedConsumerOp);
1553
+ rewriter.eraseOp (candidateSliceOp);
1507
1554
1508
1555
return scf::SCFFuseConsumerOfSliceResult{
1509
1556
consumerOp, tileAndFuseResult->tiledOps [0 ]->getResult (0 ), {}};
0 commit comments