Skip to content

Commit be547c6

Browse files
committed
Remove worklist class and add negative test
1 parent 66748e7 commit be547c6

File tree

2 files changed

+67
-101
lines changed

2 files changed

+67
-101
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 32 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,95 +1319,6 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13191319

13201320
namespace {
13211321

1322-
//===----------------------------------------------------------------------===//
1323-
// SliceWorklist
1324-
//===----------------------------------------------------------------------===//
1325-
1326-
/// Struct for tracking the number of stale entries on the worklist and whether
1327-
/// there is a remaining valid entry.
1328-
struct EntryCount {
1329-
bool isValid = true;
1330-
unsigned count = 0;
1331-
};
1332-
1333-
/// A FIFO worklist of operations with efficient removal and set semantics.
1334-
///
1335-
/// This class maintains a queue of operations and a mapping of operations to
1336-
/// positions in the vector, so that operations can be removed efficiently at
1337-
/// random. When an operation is removed, it is replaced with nullptr. Such
1338-
/// nullptr are skipped when pop'ing elements.
1339-
///
1340-
/// This is similar to the worklist used by the GreedyPatternRewriteDriver,
1341-
/// except instead FIFO so that slices for fusion can be processed breadth
1342-
/// first.
1343-
class SliceWorklist {
1344-
public:
1345-
SliceWorklist() = default;
1346-
1347-
/// Push an operation to the end of the worklist. This assumes that
1348-
/// the given operation is not already on the worklist.
1349-
void push(Operation *op);
1350-
1351-
/// Pop the an operation from the end of the worklist. Returns nullptr if
1352-
/// there are no remaining valid operations.
1353-
Operation *pop();
1354-
1355-
/// Remove an operation from the worklist.
1356-
void remove(Operation *op);
1357-
1358-
protected:
1359-
/// The queue of operations.
1360-
std::deque<Operation *> list;
1361-
1362-
/// A mapping of operations to the number of stale copies in the queue.
1363-
DenseMap<Operation *, EntryCount> map;
1364-
};
1365-
1366-
void SliceWorklist::push(Operation *op) {
1367-
assert(op && "cannot push nullptr to worklist");
1368-
list.push_back(op);
1369-
EntryCount newCount = map.lookup(op);
1370-
// Because operations are only pushed on creation, valid duplicates are
1371-
// never added.
1372-
assert((!map.contains(op) || !newCount.isValid) &&
1373-
"cannot push a duplicate operation");
1374-
map[op] = {/*isValid=*/true, newCount.count + 1};
1375-
}
1376-
1377-
Operation *SliceWorklist::pop() {
1378-
// Pop the front of the queue until we hit a valid entry.
1379-
while (!list.empty()) {
1380-
Operation *op = list.front();
1381-
list.pop_front();
1382-
1383-
EntryCount e = map.lookup(op);
1384-
// If the entry count is greater than 1 or there is no valid entry,
1385-
// this must be a stale entry. Decrement the map entry by one and continue.
1386-
if (e.count > 1 || !e.isValid) {
1387-
int64_t newCount = e.count - 1;
1388-
if (newCount <= 0)
1389-
map.erase(op);
1390-
else
1391-
map[op] = {e.isValid, static_cast<unsigned int>(newCount)};
1392-
continue;
1393-
}
1394-
1395-
map.erase(op);
1396-
return op;
1397-
}
1398-
return nullptr;
1399-
}
1400-
1401-
// Mark the operation as invalid if present. Removal from the map will
1402-
// happen later when popping from the worklist.
1403-
void SliceWorklist::remove(Operation *op) {
1404-
if (!map.contains(op))
1405-
return;
1406-
1407-
EntryCount e = map.lookup(op);
1408-
map[op] = {/*isValid=*/false, e.count};
1409-
}
1410-
14111322
//===----------------------------------------------------------------------===//
14121323
// SliceTrackingListener
14131324
//===----------------------------------------------------------------------===//
@@ -1430,15 +1341,18 @@ class SliceTrackingListener : public RewriterBase::Listener {
14301341
void notifyOperationInserted(Operation *op,
14311342
OpBuilder::InsertPoint previous) override;
14321343

1344+
/// Shared helper for operation removal from the worklist.
1345+
void removeOp(Operation *op);
1346+
14331347
/// Remove the operation from the worklist.
14341348
void notifyOperationErased(Operation *op) override;
14351349

14361350
/// Remove the operation from the worklist.
14371351
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
14381352

1439-
/// The worklist for this transformation keeps track of the operations that
1440-
/// need to be (re)visited.
1441-
SliceWorklist worklist;
1353+
/// The worklist for this transformation keeps track of the slices to visit
1354+
/// next for fusion.
1355+
std::deque<tensor::ExtractSliceOp> worklist;
14421356

14431357
private:
14441358
/// Optional pattern set to apply when adding new operations to the worklist.
@@ -1453,8 +1367,8 @@ SliceTrackingListener::SliceTrackingListener(
14531367
LogicalResult
14541368
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14551369
for (Operation *op : ops) {
1456-
if (isa<tensor::ExtractSliceOp>(op))
1457-
worklist.push(op);
1370+
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371+
worklist.push_back(slice);
14581372
}
14591373

14601374
if (!patterns)
@@ -1468,18 +1382,36 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14681382

14691383
void SliceTrackingListener::notifyOperationInserted(
14701384
Operation *op, OpBuilder::InsertPoint previous) {
1385+
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386+
if (!slice)
1387+
return;
1388+
worklist.push_back(slice);
1389+
}
1390+
1391+
// Scan the worklist for the given op and remove it if present. The expectation
1392+
// is for the worklist to be small and for removal to be relatively rare.
1393+
void SliceTrackingListener::removeOp(Operation *op) {
14711394
if (!isa<tensor::ExtractSliceOp>(op))
14721395
return;
1473-
worklist.push(op);
1396+
auto iter = worklist.begin();
1397+
while (iter != worklist.end()) {
1398+
if (*iter == op)
1399+
break;
1400+
iter++;
1401+
}
1402+
if (iter == worklist.end())
1403+
return;
1404+
1405+
worklist.erase(iter);
14741406
}
14751407

14761408
void SliceTrackingListener::notifyOperationErased(Operation *op) {
1477-
worklist.remove(op);
1409+
removeOp(op);
14781410
}
14791411

14801412
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
14811413
ValueRange replacement) {
1482-
worklist.remove(op);
1414+
removeOp(op);
14831415
}
14841416
} // namespace
14851417

@@ -1547,10 +1479,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15471479
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
15481480
}
15491481
OpBuilder::InsertionGuard g(rewriter);
1550-
while (Operation *next = sliceTracker.worklist.pop()) {
1551-
auto candidateSlice = dyn_cast<tensor::ExtractSliceOp>(next);
1552-
if (!candidateSlice)
1553-
continue;
1482+
while (!sliceTracker.worklist.empty()) {
1483+
auto candidateSlice = sliceTracker.worklist.front();
1484+
sliceTracker.worklist.pop_front();
15541485

15551486
auto [fusableProducer, destinationInitArg] =
15561487
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,38 @@ module attributes {transform.with_named_sequence} {
243243
transform.yield
244244
}
245245
}
246+
247+
// -----
248+
249+
// CHECK-LABEL: func.func @fuse_unrelated_slice
250+
func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {
251+
252+
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice
253+
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
254+
// CHECK: %[[RES:.*]] = scf.for
255+
// CHECK: scf.for
256+
// CHECK: linalg.elemwise_unary
257+
// CHECK: linalg.elemwise_binary
258+
// CHECK: return %[[RES]], %[[SLICE2]]
259+
%c0 = arith.constant 0 : index
260+
%c1 = arith.constant 1 : index
261+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
262+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
263+
%slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
264+
%slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
265+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
266+
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
267+
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
268+
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
269+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
270+
return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
271+
}
272+
273+
module attributes {transform.with_named_sequence} {
274+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
275+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
276+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
277+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
278+
transform.yield
279+
}
280+
}

0 commit comments

Comments
 (0)