Skip to content

Commit 4435ced

Browse files
[mlir][TilingInterface] Allow controlling what fusion is done within tile and fuse (#76871)
Currently the `tileConsumerAndFuseProducerGreedilyUsingSCFFor` method greedily fuses through all slices that are generated during the tile and fuse flow. That is not the normal use case. Ideally the caller would like to control which slices get fused and which don't. This patch introduces a new field to the `SCFTileAndFuseOptions` to specify this control. The control function also allows the caller to specify if the replacement for the fused producer needs to be yielded from within the tiled computation. This allows replacing the fused producers in case they have other uses. Without this the original producers still survive, negating the utility of the fusion. The change here also means that the name of the function `tileConsumerAndFuseProducerGreedily...` can be updated. Deferring that to a later stage to reduce the churn of API changes.
1 parent ce1305a commit 4435ced

File tree

3 files changed

+105
-82
lines changed

3 files changed

+105
-82
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,30 @@ struct SCFTileAndFuseOptions {
9797
tilingOptions = options;
9898
return *this;
9999
}
100+
101+
/// Control function to check if a slice needs to be fused or not,
102+
/// The control function receives
103+
/// 1) the slice along which fusion is to be done,
104+
/// 2) the producer value that is to be fused
105+
/// 3) a boolean value set to `true` if the fusion is from
106+
/// a destination operand.
107+
/// It returns two booleans
108+
/// - returns `true` if the fusion should be done through the candidate slice
109+
/// - returns `true` if a replacement for the fused producer needs to be
110+
/// yielded from within the tiled loop. Note that it is valid to return
111+
/// `true` only if the slice fused is disjoint across all iterations of the
112+
/// tiled loop. It is up to the caller to ensure that this is true for the
113+
/// fused producers.
114+
using ControlFnTy = std::function<std::tuple<bool, bool>(
115+
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
116+
bool isDestinationOperand)>;
117+
ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
118+
return std::make_tuple(true, false);
119+
};
120+
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
121+
fusionControlFn = controlFn;
122+
return *this;
123+
}
100124
};
101125

102126
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -728,32 +728,36 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
728728
}
729729

730730
// 1. First tile the consumer.
731-
SmallVector<scf::ForOp> forLoops;
732731
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
733-
DenseMap<Value, Value> replacements;
734-
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
735-
{
736-
FailureOr<scf::SCFTilingResult> tilingResult =
737-
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
738-
if (failed(tilingResult))
739-
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
740-
for (auto *tiledOp : tilingResult->tiledOps)
741-
tiledAndFusedOps.insert(tiledOp);
742-
forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
743-
for (auto [index, origValue, replacement] :
744-
llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
745-
replacements[origValue] = replacement;
746-
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
747-
index)] = index;
748-
}
749-
}
732+
llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
733+
FailureOr<scf::SCFTilingResult> tilingResult =
734+
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
735+
if (failed(tilingResult))
736+
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
737+
for (auto *tiledOp : tilingResult->tiledOps)
738+
tiledAndFusedOps.insert(tiledOp);
739+
SmallVector<scf::ForOp> forLoops =
740+
castToTypedOperations<scf::ForOp>(tilingResult->loops);
750741

751742
// If there are no loops generated, fusion is immaterial.
752743
if (forLoops.empty()) {
744+
DenseMap<Value, Value> replacements;
745+
for (auto [origVal, replacement] :
746+
llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
747+
replacements[origVal] = replacement;
748+
}
753749
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
754750
getAsOperations(forLoops), replacements};
755751
}
756752

753+
// To keep track of replacements for now just record the map from the original
754+
// untiled value to the result number of the for loop. Since the loop gets
755+
// potentially replaced during fusion, keeping the value directly won't work.
756+
DenseMap<Value, size_t> origValToResultNumber;
757+
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
758+
origValToResultNumber[result] = index;
759+
}
760+
757761
// 2. Typically, the operands of the tiled operation are slices of the
758762
// operands of the untiled operation. These are expressed in IR using
759763
// `tensor.extract_slice` operations with source being the operands of the
@@ -776,6 +780,18 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
776780
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
777781
candidates.pop_front();
778782

783+
// Find the original producer of the slice.
784+
auto [fusableProducer, destinationInitArg] =
785+
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
786+
forLoops);
787+
if (!fusableProducer)
788+
continue;
789+
790+
auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
791+
candidateSliceOp, fusableProducer, destinationInitArg.has_value());
792+
if (!fuseSlice)
793+
continue;
794+
779795
// The operands of the fused producer might themselves be slices of
780796
// values produced by operations that implement the `TilingInterface`.
781797
// Add these operations to the worklist.
@@ -784,13 +800,26 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
784800
if (!fusedResult)
785801
continue;
786802

803+
if (yieldReplacement) {
804+
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
805+
fusedResult.value(), forLoops);
806+
origValToResultNumber[fusableProducer] =
807+
forLoops.front().getNumResults() - 1;
808+
}
809+
787810
if (Operation *tiledAndFusedOp =
788811
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
789812
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
790813
tiledAndFusedOps.insert(tiledAndFusedOp);
791814
addCandidateSlices(tiledAndFusedOp, candidates);
792815
}
793816
}
817+
818+
DenseMap<Value, Value> replacements;
819+
for (auto [origVal, resultNumber] : origValToResultNumber) {
820+
replacements[origVal] = forLoops.front()->getResult(resultNumber);
821+
}
822+
794823
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
795824
getAsOperations(forLoops), replacements};
796825
}

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Lines changed: 34 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -311,80 +311,50 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
311311
// Collect list of operations that can be tiled and fused.
312312
llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
313313
collectTiledAndFusedOps(rootOp);
314-
auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
315-
return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
316-
outerMostTiledLoop->isAncestor(user);
314+
llvm::SmallDenseMap<Operation *, bool> yielded;
315+
auto isIgnoredUser = [&](Operation *user) {
316+
return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user);
317317
};
318-
319-
// The rest of this method is similar to
320-
// scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that also
321-
// yields replacements for values of the fused producer.
322-
323-
// 1. Tile the consumer.
324-
SmallVector<OpResult> yieldedValuesToOrigValues;
325-
FailureOr<scf::SCFTilingResult> tilingResult =
326-
scf::tileUsingSCFForOp(rewriter, rootOp, options);
327-
if (failed(tilingResult)) {
328-
return rewriter.notifyMatchFailure(rootOp,
329-
"failed to tile base operation");
318+
for (Operation *op : tiledAndFusedOps) {
319+
yielded[op] = llvm::any_of(op->getUsers(), [&](Operation *user) {
320+
return !isIgnoredUser(user);
321+
});
330322
}
331-
yieldedValuesToOrigValues.append(rootOp->result_begin(),
332-
rootOp->result_end());
333-
334-
// 2. Tiling each operation results in generation of slices. The source of
335-
// these slices could be producers that can be fused into the tiled loops by
336-
// computing the slices of these producers in-place. This results in more
337-
// slices created for operands of the "fused producer". This opens up more
338-
// opportunities for fusion. Use a worklist to fuse greedily.
339-
auto addCandidateSlices =
340-
[](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
341-
for (Value operand : fusedOp->getOperands())
342-
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
343-
candidates.push_back(sliceOp);
344-
};
345323

346-
std::deque<tensor::ExtractSliceOp> candidates;
347-
addCandidateSlices(tilingResult->tiledOps.back(), candidates);
348-
OpBuilder::InsertionGuard g(rewriter);
349-
auto forLoops = llvm::to_vector(llvm::map_range(
350-
tilingResult->loops, [](auto op) { return cast<scf::ForOp>(op); }));
351-
while (!candidates.empty()) {
352-
// Traverse the slices in BFS fashion.
353-
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
354-
candidates.pop_front();
355-
356-
// Materialize the slice of the producer in place.
357-
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
358-
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
359-
if (!fusedProducer)
360-
continue;
361-
362-
// Check if the fused producer has other uses that require the value
363-
// to be yielded from within the tiled loop.
364-
OpResult untiledProducer = fusedProducer->origProducer;
365-
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
366-
return !isIgnoredUser(user, forLoops.front());
367-
})) {
368-
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
369-
fusedProducer.value(), forLoops);
370-
yieldedValuesToOrigValues.push_back(untiledProducer);
371-
}
324+
scf::SCFTileAndFuseOptions tileAndFuseOptions;
325+
tileAndFuseOptions.setTilingOptions(options);
326+
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
327+
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
328+
bool isDestinationOperand) {
329+
Operation *owner = originalProducer.getOwner();
330+
return std::make_tuple(true,
331+
yielded.contains(owner) && yielded[owner]);
332+
};
333+
tileAndFuseOptions.setFusionControlFn(controlFn);
372334

373-
// Add more fusion candidates to the worklist.
374-
if (auto fusedProducerOp =
375-
fusedProducer->tiledAndFusedProducer.getDefiningOp())
376-
addCandidateSlices(fusedProducerOp, candidates);
335+
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
336+
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
337+
rewriter, rootOp, tileAndFuseOptions);
338+
if (failed(tileAndFuseResult)) {
339+
return rewriter.notifyMatchFailure(
340+
rootOp, "failed to tile and fuse with op as root");
377341
}
378342

379-
scf::ForOp outermostLoop = forLoops.front();
380-
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
381-
Value replacement = outermostLoop.getResult(index);
343+
for (auto it : tileAndFuseResult->replacements) {
344+
Value origVal = it.first;
345+
Value replacement = it.second;
382346
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
383-
return !isIgnoredUser(use.getOwner(), outermostLoop);
347+
Operation *user = use.getOwner();
348+
return !isIgnoredUser(user) &&
349+
!tileAndFuseResult->loops.front()->isAncestor(user);
384350
});
385351
}
352+
386353
rewriter.eraseOp(rootOp);
387-
filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back());
354+
for (auto tiledAndFusedOp : tileAndFuseResult->tiledAndFusedOps)
355+
if (tiledAndFusedOp->hasAttr(kTransformMarker))
356+
filter.replaceTransformationFilter(rewriter, tiledAndFusedOp);
357+
388358
return success();
389359
}
390360

0 commit comments

Comments
 (0)