Skip to content

[mlir][TilingInterface] Allow controlling what fusion is done within tile and fuse #76871

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ struct SCFTileAndFuseOptions {
tilingOptions = options;
return *this;
}

/// Control function to check if a slice needs to be fused or not,
/// The control function receives
/// 1) the slice along which fusion is to be done,
/// 2) the producer value that is to be fused
/// 3) a boolean value set to `true` if the fusion is from
/// a destination operand.
/// It retuns two booleans
/// - returns `true` if the fusion should be done through the candidate slice
/// - returns `true` if a replacement for the fused producer needs to be
/// yielded from within the tiled loop. Note that it is valid to return
/// `true` only if the slice fused is disjoint across all iterations of the
/// tiled loop. It is up to the caller to ensure that this is true for the
/// fused producers.
using ControlFnTy = std::function<std::tuple<bool, bool>(
tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand)>;
ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
return std::make_tuple(true, false);
};
SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
fusionControlFn = controlFn;
return *this;
}
};

/// Fuse the producer of the source of `candidateSliceOp` by computing the
Expand Down
65 changes: 47 additions & 18 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,32 +728,36 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
}

// 1. First tile the consumer.
SmallVector<scf::ForOp> forLoops;
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
DenseMap<Value, Value> replacements;
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
{
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
for (auto *tiledOp : tilingResult->tiledOps)
tiledAndFusedOps.insert(tiledOp);
forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
for (auto [index, origValue, replacement] :
llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
replacements[origValue] = replacement;
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
index)] = index;
}
}
llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
for (auto *tiledOp : tilingResult->tiledOps)
tiledAndFusedOps.insert(tiledOp);
SmallVector<scf::ForOp> forLoops =
castToTypedOperations<scf::ForOp>(tilingResult->loops);

// If there are no loops generated, fusion is immaterial.
if (forLoops.empty()) {
DenseMap<Value, Value> replacements;
for (auto [origVal, replacement] :
llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
replacements[origVal] = replacement;
}
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
getAsOperations(forLoops), replacements};
}

// To keep track of replacements for now just record the map from the original
// untiled value to the result number of the for loop. Since the loop gets
// potentially replaced during fusion, keeping the value directly wont work.
DenseMap<Value, size_t> origValToResultNumber;
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
origValToResultNumber[result] = index;
}

// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
// `tensor.extract_slice` operations with source being the operands of the
Expand All @@ -776,6 +780,18 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
candidates.pop_front();

// Find the original producer of the slice.
auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
forLoops);
if (!fusableProducer)
continue;

auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
candidateSliceOp, fusableProducer, destinationInitArg.has_value());
if (!fuseSlice)
continue;

// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
Expand All @@ -784,13 +800,26 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
if (!fusedResult)
continue;

if (yieldReplacement) {
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
fusedResult.value(), forLoops);
origValToResultNumber[fusableProducer] =
forLoops.front().getNumResults() - 1;
}

if (Operation *tiledAndFusedOp =
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
addCandidateSlices(tiledAndFusedOp, candidates);
}
}

DenseMap<Value, Value> replacements;
for (auto [origVal, resultNumber] : origValToResultNumber) {
replacements[origVal] = forLoops.front()->getResult(resultNumber);
}

return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
getAsOperations(forLoops), replacements};
}
Expand Down
98 changes: 34 additions & 64 deletions mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,80 +311,50 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
// Collect list of operations that can be tiled and fused.
llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
collectTiledAndFusedOps(rootOp);
auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
outerMostTiledLoop->isAncestor(user);
llvm::SmallDenseMap<Operation *, bool> yielded;
auto isIgnoredUser = [&](Operation *user) {
return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user);
};

// The rest of this method is similar to
// scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that also
// yields replacements for values of the fused producer.

// 1. Tile the consumer.
SmallVector<OpResult> yieldedValuesToOrigValues;
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCFForOp(rewriter, rootOp, options);
if (failed(tilingResult)) {
return rewriter.notifyMatchFailure(rootOp,
"failed to tile base operation");
for (Operation *op : tiledAndFusedOps) {
yielded[op] = llvm::any_of(op->getUsers(), [&](Operation *user) {
return !isIgnoredUser(user);
});
}
yieldedValuesToOrigValues.append(rootOp->result_begin(),
rootOp->result_end());

// 2. Tiling each operation results in generation of slices. The source of
// these slices could be producers that can be fused into the tiled loops by
// computing the slices of these producers in-place. This results in more
// slices created for operands of the "fused producer". This open up more
// opportunities for fusion. Use a worklist to fuse greedily.
auto addCandidateSlices =
[](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
for (Value operand : fusedOp->getOperands())
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
candidates.push_back(sliceOp);
};

std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tilingResult->tiledOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
auto forLoops = llvm::to_vector(llvm::map_range(
tilingResult->loops, [](auto op) { return cast<scf::ForOp>(op); }));
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
candidates.pop_front();

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedProducer)
continue;

// Check if the fused producer has other uses that require the value
// to be yielded from within the tiled loop.
OpResult untiledProducer = fusedProducer->origProducer;
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
return !isIgnoredUser(user, forLoops.front());
})) {
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
fusedProducer.value(), forLoops);
yieldedValuesToOrigValues.push_back(untiledProducer);
}
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(options);
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand) {
Operation *owner = originalProducer.getOwner();
return std::make_tuple(true,
yielded.contains(owner) && yielded[owner]);
};
tileAndFuseOptions.setFusionControlFn(controlFn);

// Add more fusion candidates to the worklist.
if (auto fusedProducerOp =
fusedProducer->tiledAndFusedProducer.getDefiningOp())
addCandidateSlices(fusedProducerOp, candidates);
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, rootOp, tileAndFuseOptions);
if (failed(tileAndFuseResult)) {
return rewriter.notifyMatchFailure(
rootOp, "failed to tile and fuse with op as root");
}

scf::ForOp outermostLoop = forLoops.front();
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
Value replacement = outermostLoop.getResult(index);
for (auto it : tileAndFuseResult->replacements) {
Value origVal = it.first;
Value replacement = it.second;
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
return !isIgnoredUser(use.getOwner(), outermostLoop);
Operation *user = use.getOwner();
return !isIgnoredUser(user) &&
!tileAndFuseResult->loops.front()->isAncestor(user);
});
}

rewriter.eraseOp(rootOp);
filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back());
for (auto tiledAndFusedOp : tileAndFuseResult->tiledAndFusedOps)
if (tiledAndFusedOp->hasAttr(kTransformMarker))
filter.replaceTransformationFilter(rewriter, tiledAndFusedOp);

return success();
}

Expand Down