24
24
#include " mlir/IR/PatternMatch.h"
25
25
#include " mlir/Interfaces/DestinationStyleOpInterface.h"
26
26
#include " mlir/Interfaces/TilingInterface.h"
27
+ #include " mlir/Rewrite/FrozenRewritePatternSet.h"
28
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
27
29
#include " llvm/ADT/TypeSwitch.h"
28
30
#include " llvm/Support/Debug.h"
29
31
#include < optional>
@@ -1315,6 +1317,172 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1315
1317
return generatedSlices;
1316
1318
}
1317
1319
1320
+ namespace {
1321
+
1322
+ // ===----------------------------------------------------------------------===//
1323
+ // SliceWorklist
1324
+ // ===----------------------------------------------------------------------===//
1325
+
1326
+ // / Struct for tracking the number of stale entries on the worklist and whether
1327
+ // / there is a remaining valid entry.
1328
+ struct EntryCount {
1329
+ bool isValid = true ;
1330
+ unsigned count = 0 ;
1331
+ };
1332
+
1333
+ // / A FIFO worklist of operations with efficient removal and set semantics.
1334
+ // /
1335
+ // / This class maintains a queue of operations and a mapping of operations to
1336
+ // / positions in the vector, so that operations can be removed efficiently at
1337
+ // / random. When an operation is removed, it is replaced with nullptr. Such
1338
+ // / nullptr are skipped when pop'ing elements.
1339
+ // /
1340
+ // / This is similar to the worklist used by the GreedyPatternRewriteDriver,
1341
+ // / except instead FIFO so that slices for fusion can be processed breadth
1342
+ // / first.
1343
+ class SliceWorklist {
1344
+ public:
1345
+ SliceWorklist () = default ;
1346
+
1347
+ // / Push an operation to the end of the worklist. This assumes that
1348
+ // / the given operation is not already on the worklist.
1349
+ void push (Operation *op);
1350
+
1351
+ // / Pop the an operation from the end of the worklist. Returns nullptr if
1352
+ // / there are no remaining valid operations.
1353
+ Operation *pop ();
1354
+
1355
+ // / Remove an operation from the worklist.
1356
+ void remove (Operation *op);
1357
+
1358
+ protected:
1359
+ // / The queue of operations.
1360
+ std::deque<Operation *> list;
1361
+
1362
+ // / A mapping of operations to the number of stale copies in the queue.
1363
+ DenseMap<Operation *, EntryCount> map;
1364
+ };
1365
+
1366
+ void SliceWorklist::push (Operation *op) {
1367
+ assert (op && " cannot push nullptr to worklist" );
1368
+ list.push_back (op);
1369
+ EntryCount newCount = map.lookup (op);
1370
+ // Because operations are only pushed on creation, valid duplicates are
1371
+ // never added.
1372
+ assert ((!map.contains (op) || !newCount.isValid ) &&
1373
+ " cannot push a duplicate operation" );
1374
+ map[op] = {/* isValid=*/ true , newCount.count + 1 };
1375
+ }
1376
+
1377
+ Operation *SliceWorklist::pop () {
1378
+ // Pop the front of the queue until we hit a valid entry.
1379
+ while (!list.empty ()) {
1380
+ Operation *op = list.front ();
1381
+ list.pop_front ();
1382
+
1383
+ EntryCount e = map.lookup (op);
1384
+ // If the entry count is greater than 1 or there is no valid entry,
1385
+ // this must be a stale entry. Decrement the map entry by one and continue.
1386
+ if (e.count > 1 || !e.isValid ) {
1387
+ int64_t newCount = e.count - 1 ;
1388
+ if (newCount <= 0 )
1389
+ map.erase (op);
1390
+ else
1391
+ map[op] = {e.isValid , static_cast <unsigned int >(newCount)};
1392
+ continue ;
1393
+ }
1394
+
1395
+ map.erase (op);
1396
+ return op;
1397
+ }
1398
+ return nullptr ;
1399
+ }
1400
+
1401
+ // Mark the operation as invalid if present. Removal from the map will
1402
+ // happen later when popping from the worklist.
1403
+ void SliceWorklist::remove (Operation *op) {
1404
+ if (!map.contains (op))
1405
+ return ;
1406
+
1407
+ EntryCount e = map.lookup (op);
1408
+ map[op] = {/* isValid=*/ false , e.count };
1409
+ }
1410
+
1411
+ // ===----------------------------------------------------------------------===//
1412
+ // SliceTrackingListener
1413
+ // ===----------------------------------------------------------------------===//
1414
+
1415
+ // / This class is a listener for tracking the insertion and removal of
1416
+ // / `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1417
+ // / fusion algorithm to apply cleanup patterns in between fusion steps.
1418
+ class SliceTrackingListener : public RewriterBase ::Listener {
1419
+ public:
1420
+ explicit SliceTrackingListener (
1421
+ std::optional<FrozenRewritePatternSet> patterns);
1422
+ SliceTrackingListener () = default ;
1423
+
1424
+ // / Adds the given list of operations to the worklist, and if present, applies
1425
+ // / the list of `patterns` to the newly added operations. This only processes
1426
+ // / the given operations and any newly inserted ones by the pattern set.
1427
+ LogicalResult insertAndApplyPatterns (ArrayRef<Operation *> newOps);
1428
+
1429
+ // / Add to the new operation worklist if it is an extract_slice.
1430
+ void notifyOperationInserted (Operation *op,
1431
+ OpBuilder::InsertPoint previous) override ;
1432
+
1433
+ // / Remove the operation from the worklist.
1434
+ void notifyOperationErased (Operation *op) override ;
1435
+
1436
+ // / Remove the operation from the worklist.
1437
+ void notifyOperationReplaced (Operation *op, ValueRange replacement) override ;
1438
+
1439
+ // / The worklist for this transformation keeps track of the operations that
1440
+ // / need to be (re)visited.
1441
+ SliceWorklist worklist;
1442
+
1443
+ private:
1444
+ // / Optional pattern set to apply when adding new operations to the worklist.
1445
+ std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1446
+ };
1447
+
1448
+ SliceTrackingListener::SliceTrackingListener (
1449
+ std::optional<FrozenRewritePatternSet> p) {
1450
+ patterns = std::move (p);
1451
+ }
1452
+
1453
+ LogicalResult
1454
+ SliceTrackingListener::insertAndApplyPatterns (ArrayRef<Operation *> ops) {
1455
+ for (Operation *op : ops) {
1456
+ if (isa<tensor::ExtractSliceOp>(op))
1457
+ worklist.push (op);
1458
+ }
1459
+
1460
+ if (!patterns)
1461
+ return success ();
1462
+
1463
+ GreedyRewriteConfig config;
1464
+ config.listener = this ;
1465
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1466
+ return applyOpPatternsAndFold (ops, patterns.value (), config);
1467
+ }
1468
+
1469
+ void SliceTrackingListener::notifyOperationInserted (
1470
+ Operation *op, OpBuilder::InsertPoint previous) {
1471
+ if (!isa<tensor::ExtractSliceOp>(op))
1472
+ return ;
1473
+ worklist.push (op);
1474
+ }
1475
+
1476
+ void SliceTrackingListener::notifyOperationErased (Operation *op) {
1477
+ worklist.remove (op);
1478
+ }
1479
+
1480
+ void SliceTrackingListener::notifyOperationReplaced (Operation *op,
1481
+ ValueRange replacement) {
1482
+ worklist.remove (op);
1483
+ }
1484
+ } // namespace
1485
+
1318
1486
// / Implementation of tile consumer and fuse producer greedily.
1319
1487
FailureOr<scf::SCFTileAndFuseResult>
1320
1488
mlir::scf::tileConsumerAndFuseProducersUsingSCF (
@@ -1370,33 +1538,33 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1370
1538
tensor::ExtractSliceOp candidateSlice;
1371
1539
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1372
1540
};
1373
- std::deque<WorklistItem> worklist;
1374
- auto addCandidateSlices = [&worklist, &options,
1375
- &loops](ArrayRef<Operation *> candidates) {
1376
- for (auto candidate : candidates) {
1377
- auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378
- if (!sliceOp || sliceOp.use_empty ())
1379
- continue ;
1380
1541
1381
- auto [fusableProducer, destinationInitArg] =
1382
- getUntiledProducerFromSliceSource (&sliceOp.getSourceMutable (), loops);
1383
- if (!fusableProducer)
1384
- continue ;
1385
- std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386
- options.fusionControlFn (sliceOp, fusableProducer,
1387
- destinationInitArg.has_value ());
1388
- if (!controlFnResult)
1389
- continue ;
1390
- worklist.emplace_back (WorklistItem{sliceOp, controlFnResult.value ()});
1391
- }
1392
- };
1542
+ SliceTrackingListener sliceTracker =
1543
+ SliceTrackingListener (options.cleanupPatterns );
1393
1544
1394
- addCandidateSlices (tilingResult->generatedSlices );
1545
+ if (failed (
1546
+ sliceTracker.insertAndApplyPatterns (tilingResult->generatedSlices ))) {
1547
+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1548
+ }
1395
1549
OpBuilder::InsertionGuard g (rewriter);
1396
- while (!worklist.empty ()) {
1397
- // Traverse the slices in BFS fashion.
1398
- WorklistItem worklistItem = worklist.front ();
1399
- worklist.pop_front ();
1550
+ while (Operation *next = sliceTracker.worklist .pop ()) {
1551
+ auto candidateSlice = dyn_cast<tensor::ExtractSliceOp>(next);
1552
+ if (!candidateSlice)
1553
+ continue ;
1554
+
1555
+ auto [fusableProducer, destinationInitArg] =
1556
+ getUntiledProducerFromSliceSource (&candidateSlice.getSourceMutable (),
1557
+ loops);
1558
+ if (!fusableProducer)
1559
+ continue ;
1560
+
1561
+ std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1562
+ options.fusionControlFn (candidateSlice, fusableProducer,
1563
+ destinationInitArg.has_value ());
1564
+ if (!controlFnResult)
1565
+ continue ;
1566
+
1567
+ WorklistItem worklistItem = {candidateSlice, controlFnResult.value ()};
1400
1568
1401
1569
// The operands of the fused producer might themselved be slices of
1402
1570
// values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1575,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1407
1575
if (!fusedResult)
1408
1576
continue ;
1409
1577
1578
+ SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices ;
1579
+
1410
1580
if (worklistItem.controlFnResult .yieldProducerReplacement ) {
1411
1581
// Reconstruct and yield all opResult of fusableProducerOp by default. The
1412
1582
// caller can specific which one to yield by designating optional argument
@@ -1421,20 +1591,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1421
1591
fusableProducerOp, " failed to replacement value for this "
1422
1592
" operation from within the tiled loop" );
1423
1593
}
1424
- addCandidateSlices (newSlices.value ());
1594
+ worklistCandidates. append (newSlices.value ());
1425
1595
for (auto [index, result] :
1426
1596
llvm::enumerate (fusableProducerOp->getResults ())) {
1427
1597
origValToResultNumber[result] = loops.front ()->getNumResults () -
1428
1598
fusableProducerOp->getNumResults () +
1429
1599
index;
1430
1600
}
1431
1601
}
1432
- addCandidateSlices (fusedResult->generatedSlices );
1433
1602
if (Operation *tiledAndFusedOp =
1434
1603
fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
1435
1604
fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
1436
1605
tiledAndFusedOps.insert (tiledAndFusedOp);
1437
1606
}
1607
+
1608
+ if (failed (sliceTracker.insertAndApplyPatterns (worklistCandidates))) {
1609
+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1610
+ }
1438
1611
}
1439
1612
1440
1613
DenseMap<Value, Value> replacements;
0 commit comments