@@ -1319,95 +1319,6 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1319
1319
1320
1320
namespace {
1321
1321
1322
- // ===----------------------------------------------------------------------===//
1323
- // SliceWorklist
1324
- // ===----------------------------------------------------------------------===//
1325
-
1326
- // / Struct for tracking the number of stale entries on the worklist and whether
1327
- // / there is a remaining valid entry.
1328
- struct EntryCount {
1329
- bool isValid = true ;
1330
- unsigned count = 0 ;
1331
- };
1332
-
1333
- // / A FIFO worklist of operations with efficient removal and set semantics.
1334
- // /
1335
- // / This class maintains a queue of operations and a mapping of operations to
1336
- // / positions in the vector, so that operations can be removed efficiently at
1337
- // / random. When an operation is removed, it is replaced with nullptr. Such
1338
- // / nullptr are skipped when pop'ing elements.
1339
- // /
1340
- // / This is similar to the worklist used by the GreedyPatternRewriteDriver,
1341
- // / except instead FIFO so that slices for fusion can be processed breadth
1342
- // / first.
1343
- class SliceWorklist {
1344
- public:
1345
- SliceWorklist () = default ;
1346
-
1347
- // / Push an operation to the end of the worklist. This assumes that
1348
- // / the given operation is not already on the worklist.
1349
- void push (Operation *op);
1350
-
1351
- // / Pop the an operation from the end of the worklist. Returns nullptr if
1352
- // / there are no remaining valid operations.
1353
- Operation *pop ();
1354
-
1355
- // / Remove an operation from the worklist.
1356
- void remove (Operation *op);
1357
-
1358
- protected:
1359
- // / The queue of operations.
1360
- std::deque<Operation *> list;
1361
-
1362
- // / A mapping of operations to the number of stale copies in the queue.
1363
- DenseMap<Operation *, EntryCount> map;
1364
- };
1365
-
1366
- void SliceWorklist::push (Operation *op) {
1367
- assert (op && " cannot push nullptr to worklist" );
1368
- list.push_back (op);
1369
- EntryCount newCount = map.lookup (op);
1370
- // Because operations are only pushed on creation, valid duplicates are
1371
- // never added.
1372
- assert ((!map.contains (op) || !newCount.isValid ) &&
1373
- " cannot push a duplicate operation" );
1374
- map[op] = {/* isValid=*/ true , newCount.count + 1 };
1375
- }
1376
-
1377
- Operation *SliceWorklist::pop () {
1378
- // Pop the front of the queue until we hit a valid entry.
1379
- while (!list.empty ()) {
1380
- Operation *op = list.front ();
1381
- list.pop_front ();
1382
-
1383
- EntryCount e = map.lookup (op);
1384
- // If the entry count is greater than 1 or there is no valid entry,
1385
- // this must be a stale entry. Decrement the map entry by one and continue.
1386
- if (e.count > 1 || !e.isValid ) {
1387
- int64_t newCount = e.count - 1 ;
1388
- if (newCount <= 0 )
1389
- map.erase (op);
1390
- else
1391
- map[op] = {e.isValid , static_cast <unsigned int >(newCount)};
1392
- continue ;
1393
- }
1394
-
1395
- map.erase (op);
1396
- return op;
1397
- }
1398
- return nullptr ;
1399
- }
1400
-
1401
- // Mark the operation as invalid if present. Removal from the map will
1402
- // happen later when popping from the worklist.
1403
- void SliceWorklist::remove (Operation *op) {
1404
- if (!map.contains (op))
1405
- return ;
1406
-
1407
- EntryCount e = map.lookup (op);
1408
- map[op] = {/* isValid=*/ false , e.count };
1409
- }
1410
-
1411
1322
// ===----------------------------------------------------------------------===//
1412
1323
// SliceTrackingListener
1413
1324
// ===----------------------------------------------------------------------===//
@@ -1430,15 +1341,18 @@ class SliceTrackingListener : public RewriterBase::Listener {
1430
1341
void notifyOperationInserted (Operation *op,
1431
1342
OpBuilder::InsertPoint previous) override ;
1432
1343
1344
+ // / Shared helper for operation removal from the worklist.
1345
+ void removeOp (Operation *op);
1346
+
1433
1347
// / Remove the operation from the worklist.
1434
1348
void notifyOperationErased (Operation *op) override ;
1435
1349
1436
1350
// / Remove the operation from the worklist.
1437
1351
void notifyOperationReplaced (Operation *op, ValueRange replacement) override ;
1438
1352
1439
- // / The worklist for this transformation keeps track of the operations that
1440
- // / need to be (re)visited .
1441
- SliceWorklist worklist;
1353
+ // / The worklist for this transformation keeps track of the slices to visit
1354
+ // / next for fusion .
1355
+ std::deque<tensor::ExtractSliceOp> worklist;
1442
1356
1443
1357
private:
1444
1358
// / Optional pattern set to apply when adding new operations to the worklist.
@@ -1453,8 +1367,8 @@ SliceTrackingListener::SliceTrackingListener(
1453
1367
LogicalResult
1454
1368
SliceTrackingListener::insertAndApplyPatterns (ArrayRef<Operation *> ops) {
1455
1369
for (Operation *op : ops) {
1456
- if (isa <tensor::ExtractSliceOp>(op))
1457
- worklist.push (op );
1370
+ if (auto slice = dyn_cast <tensor::ExtractSliceOp>(op))
1371
+ worklist.push_back (slice );
1458
1372
}
1459
1373
1460
1374
if (!patterns)
@@ -1468,18 +1382,36 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1468
1382
1469
1383
void SliceTrackingListener::notifyOperationInserted (
1470
1384
Operation *op, OpBuilder::InsertPoint previous) {
1385
+ auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386
+ if (!slice)
1387
+ return ;
1388
+ worklist.push_back (slice);
1389
+ }
1390
+
1391
+ // Scan the worklist for the given op and remove it if present. The expectation
1392
+ // is for the worklist to be small and for removal to be relatively rare.
1393
+ void SliceTrackingListener::removeOp (Operation *op) {
1471
1394
if (!isa<tensor::ExtractSliceOp>(op))
1472
1395
return ;
1473
- worklist.push (op);
1396
+ auto iter = worklist.begin ();
1397
+ while (iter != worklist.end ()) {
1398
+ if (*iter == op)
1399
+ break ;
1400
+ iter++;
1401
+ }
1402
+ if (iter == worklist.end ())
1403
+ return ;
1404
+
1405
+ worklist.erase (iter);
1474
1406
}
1475
1407
1476
1408
void SliceTrackingListener::notifyOperationErased (Operation *op) {
1477
- worklist. remove (op);
1409
+ removeOp (op);
1478
1410
}
1479
1411
1480
1412
void SliceTrackingListener::notifyOperationReplaced (Operation *op,
1481
1413
ValueRange replacement) {
1482
- worklist. remove (op);
1414
+ removeOp (op);
1483
1415
}
1484
1416
} // namespace
1485
1417
@@ -1547,10 +1479,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1547
1479
return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1548
1480
}
1549
1481
OpBuilder::InsertionGuard g (rewriter);
1550
- while (Operation *next = sliceTracker.worklist .pop ()) {
1551
- auto candidateSlice = dyn_cast<tensor::ExtractSliceOp>(next);
1552
- if (!candidateSlice)
1553
- continue ;
1482
+ while (!sliceTracker.worklist .empty ()) {
1483
+ auto candidateSlice = sliceTracker.worklist .front ();
1484
+ sliceTracker.worklist .pop_front ();
1554
1485
1555
1486
auto [fusableProducer, destinationInitArg] =
1556
1487
getUntiledProducerFromSliceSource (&candidateSlice.getSourceMutable (),
0 commit comments