Skip to content

Commit 5746940

Browse files
committed
[mlir] Add option for a cleanup pattern set to SCF tiling helper
The SCF helper for tiling an operation implementing the TilingInterface and greedily fusing consumers requires an uninterrupted chain of operations implementing the tiling interface to succeed. There can be cases with intermediate ops that don't implement the interface but have producers that could be fused if various canonicalization/simplification patterns could run in between fusion steps. This adds an option to SCFTileAndFuseOptions for a pattern set to run between fusion steps to the ops that result from fusion/tiling. Removed and newly inserted slices are tracked for continued fusion applications. See this RFC for more discussion: https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155
1 parent 1833d41 commit 5746940

File tree

5 files changed

+286
-28
lines changed

5 files changed

+286
-28
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
284284
let description = [{
285285
Tiles the operations pointed to by the target handle and fuses their
286286
producers greedily using the options provided as attributes.
287+
288+
If `apply_cleanup` is true then slice canonicalization is applied between
289+
fusion steps.
287290
}];
288291

289292
let arguments =
290293
(ins TransformHandleTypeInterface:$target,
291294
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
292-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
295+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
296+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
293297
let results = (outs TransformHandleTypeInterface:$transformed,
294298
Variadic<TransformHandleTypeInterface>:$loops);
295299

296300
let assemblyFormat = [{
297301
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
298-
attr-dict `:` functional-type(operands, results)
302+
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
303+
`:` functional-type(operands, results)
299304
}];
300305
let hasVerifier = 1;
301306
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Interfaces/LoopLikeInterface.h"
1616
#include "mlir/Interfaces/TilingInterface.h"
1717
#include "mlir/Interfaces/ViewLikeInterface.h"
18+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
1819

1920
#include <deque>
2021

@@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions {
153154
fusionControlFn = controlFn;
154155
return *this;
155156
}
157+
158+
/// An optional set of rewrite patterns to apply to the results of tiling
159+
/// before fusion. This will track deleted and newly inserted
160+
/// `tensor.extract_slice` ops and update the worklist.
161+
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
156162
};
157163

158164
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
557557
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
558558
scf::SCFTileAndFuseOptions tileAndFuseOptions;
559559
tileAndFuseOptions.tilingOptions = tilingOptions;
560+
561+
if (getApplyCleanup()) {
562+
MLIRContext *context = rewriter.getContext();
563+
RewritePatternSet patterns(context);
564+
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
565+
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
566+
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
567+
}
568+
560569
LogicalResult result = applyTilingToAll(
561570
rewriter, getOperation(), state.getPayloadOps(getTarget()),
562571
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 199 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "mlir/IR/PatternMatch.h"
2525
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2626
#include "mlir/Interfaces/TilingInterface.h"
27+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Support/Debug.h"
2931
#include <optional>
@@ -1315,6 +1317,172 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13151317
return generatedSlices;
13161318
}
13171319

1320+
namespace {
1321+
1322+
//===----------------------------------------------------------------------===//
1323+
// SliceWorklist
1324+
//===----------------------------------------------------------------------===//
1325+
1326+
/// Struct for tracking the number of stale entries on the worklist and whether
1327+
/// there is a remaining valid entry.
1328+
struct EntryCount {
1329+
bool isValid = true;
1330+
unsigned count = 0;
1331+
};
1332+
1333+
/// A FIFO worklist of operations with efficient removal and set semantics.
1334+
///
1335+
/// This class maintains a queue of operations and a mapping of operations to
1336+
/// entry counts, so that operations can be removed efficiently at random.
1337+
/// When an operation is removed, its map entry is marked invalid. Such stale
1338+
/// queue entries are skipped when pop'ing elements.
1339+
///
1340+
/// This is similar to the worklist used by the GreedyPatternRewriteDriver,
1341+
/// except it is FIFO instead, so that slices for fusion can be processed breadth
1342+
/// first.
1343+
class SliceWorklist {
1344+
public:
1345+
SliceWorklist() = default;
1346+
1347+
/// Push an operation to the end of the worklist. This assumes that
1348+
/// the given operation is not already on the worklist.
1349+
void push(Operation *op);
1350+
1351+
/// Pop an operation from the front of the worklist. Returns nullptr if
1352+
/// there are no remaining valid operations.
1353+
Operation *pop();
1354+
1355+
/// Remove an operation from the worklist.
1356+
void remove(Operation *op);
1357+
1358+
protected:
1359+
/// The queue of operations.
1360+
std::deque<Operation *> list;
1361+
1362+
/// A mapping of operations to their number of copies in the queue and whether
/// one of those copies is still valid.
1363+
DenseMap<Operation *, EntryCount> map;
1364+
};
1365+
1366+
void SliceWorklist::push(Operation *op) {
1367+
assert(op && "cannot push nullptr to worklist");
1368+
list.push_back(op);
1369+
EntryCount newCount = map.lookup(op);
1370+
// Because operations are only pushed on creation, valid duplicates are
1371+
// never added.
1372+
assert((!map.contains(op) || !newCount.isValid) &&
1373+
"cannot push a duplicate operation");
1374+
map[op] = {/*isValid=*/true, newCount.count + 1};
1375+
}
1376+
1377+
Operation *SliceWorklist::pop() {
1378+
// Pop the front of the queue until we hit a valid entry.
1379+
while (!list.empty()) {
1380+
Operation *op = list.front();
1381+
list.pop_front();
1382+
1383+
EntryCount e = map.lookup(op);
1384+
// If the entry count is greater than 1 or there is no valid entry,
1385+
// this must be a stale entry. Decrement the map entry by one and continue.
1386+
if (e.count > 1 || !e.isValid) {
1387+
int64_t newCount = e.count - 1;
1388+
if (newCount <= 0)
1389+
map.erase(op);
1390+
else
1391+
map[op] = {e.isValid, static_cast<unsigned int>(newCount)};
1392+
continue;
1393+
}
1394+
1395+
map.erase(op);
1396+
return op;
1397+
}
1398+
return nullptr;
1399+
}
1400+
1401+
// Mark the operation as invalid if present. Removal from the map will
1402+
// happen later when popping from the worklist.
1403+
void SliceWorklist::remove(Operation *op) {
1404+
if (!map.contains(op))
1405+
return;
1406+
1407+
EntryCount e = map.lookup(op);
1408+
map[op] = {/*isValid=*/false, e.count};
1409+
}
1410+
1411+
//===----------------------------------------------------------------------===//
1412+
// SliceTrackingListener
1413+
//===----------------------------------------------------------------------===//
1414+
1415+
/// This class is a listener for tracking the insertion and removal of
1416+
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1417+
/// fusion algorithm to apply cleanup patterns in between fusion steps.
1418+
class SliceTrackingListener : public RewriterBase::Listener {
1419+
public:
1420+
explicit SliceTrackingListener(
1421+
std::optional<FrozenRewritePatternSet> patterns);
1422+
SliceTrackingListener() = default;
1423+
1424+
/// Adds the given list of operations to the worklist, and if present, applies
1425+
/// the list of `patterns` to the newly added operations. This only processes
1426+
/// the given operations and any newly inserted ones by the pattern set.
1427+
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1428+
1429+
/// Add the new operation to the worklist if it is an extract_slice.
1430+
void notifyOperationInserted(Operation *op,
1431+
OpBuilder::InsertPoint previous) override;
1432+
1433+
/// Remove the operation from the worklist.
1434+
void notifyOperationErased(Operation *op) override;
1435+
1436+
/// Remove the operation from the worklist.
1437+
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1438+
1439+
/// The worklist for this transformation keeps track of the operations that
1440+
/// need to be (re)visited.
1441+
SliceWorklist worklist;
1442+
1443+
private:
1444+
/// Optional pattern set to apply when adding new operations to the worklist.
1445+
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1446+
};
1447+
1448+
SliceTrackingListener::SliceTrackingListener(
1449+
std::optional<FrozenRewritePatternSet> p) {
1450+
patterns = std::move(p);
1451+
}
1452+
1453+
LogicalResult
1454+
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1455+
for (Operation *op : ops) {
1456+
if (isa<tensor::ExtractSliceOp>(op))
1457+
worklist.push(op);
1458+
}
1459+
1460+
if (!patterns)
1461+
return success();
1462+
1463+
GreedyRewriteConfig config;
1464+
config.listener = this;
1465+
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1466+
return applyOpPatternsAndFold(ops, patterns.value(), config);
1467+
}
1468+
1469+
void SliceTrackingListener::notifyOperationInserted(
1470+
Operation *op, OpBuilder::InsertPoint previous) {
1471+
if (!isa<tensor::ExtractSliceOp>(op))
1472+
return;
1473+
worklist.push(op);
1474+
}
1475+
1476+
void SliceTrackingListener::notifyOperationErased(Operation *op) {
1477+
worklist.remove(op);
1478+
}
1479+
1480+
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1481+
ValueRange replacement) {
1482+
worklist.remove(op);
1483+
}
1484+
} // namespace
1485+
13181486
/// Implementation of tile consumer and fuse producer greedily.
13191487
FailureOr<scf::SCFTileAndFuseResult>
13201488
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
@@ -1370,33 +1538,33 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
13701538
tensor::ExtractSliceOp candidateSlice;
13711539
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
13721540
};
1373-
std::deque<WorklistItem> worklist;
1374-
auto addCandidateSlices = [&worklist, &options,
1375-
&loops](ArrayRef<Operation *> candidates) {
1376-
for (auto candidate : candidates) {
1377-
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378-
if (!sliceOp || sliceOp.use_empty())
1379-
continue;
13801541

1381-
auto [fusableProducer, destinationInitArg] =
1382-
getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
1383-
if (!fusableProducer)
1384-
continue;
1385-
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386-
options.fusionControlFn(sliceOp, fusableProducer,
1387-
destinationInitArg.has_value());
1388-
if (!controlFnResult)
1389-
continue;
1390-
worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
1391-
}
1392-
};
1542+
SliceTrackingListener sliceTracker =
1543+
SliceTrackingListener(options.cleanupPatterns);
13931544

1394-
addCandidateSlices(tilingResult->generatedSlices);
1545+
if (failed(
1546+
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1547+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1548+
}
13951549
OpBuilder::InsertionGuard g(rewriter);
1396-
while (!worklist.empty()) {
1397-
// Traverse the slices in BFS fashion.
1398-
WorklistItem worklistItem = worklist.front();
1399-
worklist.pop_front();
1550+
while (Operation *next = sliceTracker.worklist.pop()) {
1551+
auto candidateSlice = dyn_cast<tensor::ExtractSliceOp>(next);
1552+
if (!candidateSlice)
1553+
continue;
1554+
1555+
auto [fusableProducer, destinationInitArg] =
1556+
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1557+
loops);
1558+
if (!fusableProducer)
1559+
continue;
1560+
1561+
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1562+
options.fusionControlFn(candidateSlice, fusableProducer,
1563+
destinationInitArg.has_value());
1564+
if (!controlFnResult)
1565+
continue;
1566+
1567+
WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
14001568

14011569
// The operands of the fused producer might themselves be slices of
14021570
// values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1575,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14071575
if (!fusedResult)
14081576
continue;
14091577

1578+
SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1579+
14101580
if (worklistItem.controlFnResult.yieldProducerReplacement) {
14111581
// Reconstruct and yield all opResult of fusableProducerOp by default. The
14121582
// caller can specify which one to yield by designating optional argument
@@ -1421,20 +1591,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14211591
fusableProducerOp, "failed to replacement value for this "
14221592
"operation from within the tiled loop");
14231593
}
1424-
addCandidateSlices(newSlices.value());
1594+
worklistCandidates.append(newSlices.value());
14251595
for (auto [index, result] :
14261596
llvm::enumerate(fusableProducerOp->getResults())) {
14271597
origValToResultNumber[result] = loops.front()->getNumResults() -
14281598
fusableProducerOp->getNumResults() +
14291599
index;
14301600
}
14311601
}
1432-
addCandidateSlices(fusedResult->generatedSlices);
14331602
if (Operation *tiledAndFusedOp =
14341603
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
14351604
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
14361605
tiledAndFusedOps.insert(tiledAndFusedOp);
14371606
}
1607+
1608+
if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1609+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1610+
}
14381611
}
14391612

14401613
DenseMap<Value, Value> replacements;

0 commit comments

Comments
 (0)