Skip to content

[SCF][Transform] Add support for scf.for in LoopFuseSibling op #81495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -333,23 +333,24 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
}];
}

def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";

let description = [{
Fuses the `target` loop into the `source` loop assuming they are
independent of each other. It is the responsibility of the user to ensure
that the given two loops are independent of each other, this operation will
not performa any legality checks and will simply fuse the two given loops.
independent of each other. In the fused loop, the arguments, body and
results of `target` are placed _before_ those of `source`.

Currently, the only fusion supported is when both `target` and `source`
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
mapping must match, otherwise a silencable failure is produced.
For fusion of two `scf.for` loops, the bounds and step size must match. For
fusion of two `scf.forall` loops, the bounds and the mapping must match.
Otherwise a silencable failure is produced.

The input handles `target` and `source` must map to exactly one operation,
a definite failure is produced otherwise.
The `target` and `source` handles must refer to exactly one operation,
otherwise a definite failure is produced. It is the responsibility of the
user to ensure that the `target` and `source` loops are independent of each
other -- this op will only perform rudimentary legality checks.

#### Return modes

Expand All @@ -362,10 +363,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
let results = (outs TransformHandleTypeInterface:$fused_loop);
let assemblyFormat = "$target `into` $source attr-dict "
" `:` functional-type(operands, results)";

let builders = [
OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
];
}

#endif // SCF_TRANSFORM_OPS
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter);

/// Given two scf.for loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
///
/// This function does not perform any legality checks and simply fuses the
/// loops. The caller is responsible for ensuring that the loops are legal to
/// fuse.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);

} // namespace mlir

#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
66 changes: 39 additions & 27 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
}

//===----------------------------------------------------------------------===//
// LoopFuseSibling
// LoopFuseSiblingOp
//===----------------------------------------------------------------------===//

/// Check if `target` and `source` are siblings, in the context that `target`
Expand All @@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
// Check if fusion will violate dominance.
DominanceInfo domInfo(source);
if (target->isBeforeInBlock(source)) {
// Since, `target` is before `source`, all users of results of `target`
// Since `target` is before `source`, all users of results of `target`
// need to be dominated by `source`.
for (Operation *user : target->getUsers()) {
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
Expand All @@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
// Check if operands of `target` are dominated by `source`.
for (Value operand : target->getOperands()) {
Operation *operandOp = operand.getDefiningOp();
// If operand does not have a defining operation, it is a block arguement,
// which will always dominate `source`, since `target` and `source` are in
// the same block and the operand dominated `source` before.
// Operands without defining operations are block arguments. When `target`
// and `source` occur in the same block, these operands dominate `source`.
if (!operandOp)
continue;

Expand All @@ -441,8 +440,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
bool failed = false;
OpOperand *failedValue = nullptr;
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
if (!domInfo.properlyDominates(operand->getOwner(), source,
/*enclosingOpOk=*/false)) {
Operation *operandOp = operand->get().getDefiningOp();
if (operandOp && !domInfo.properlyDominates(operandOp, source,
/*enclosingOpOk=*/false)) {
// `operand` is not an argument of an enclosing block and the defining
// op of `operand` is outside `target` but does not dominate `source`.
failed = true;
failedValue = operand;
}
Expand All @@ -457,12 +459,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
return DiagnosedSilenceableFailure::success();
}

/// Check if `target` can be fused into `source`.
/// Check if `target` scf.forall can be fused into `source` scf.forall.
///
/// This is a simple check that just checks if both loops have same
/// bounds, steps and mapping. This check does not ensure that the side effects
/// of `target` are independent of `source` or vice-versa. It is the
/// responsibility of the caller to ensure that.
/// This simply checks if both loops have the same bounds, steps and mapping.
/// No attempt is made at checking that the side effects of `target` and
/// `source` are independent of each other.
static bool isForallWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForallOp>(target);
Expand All @@ -476,21 +477,27 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
targetOp.getMapping() == sourceOp.getMapping();
}

/// Fuse `target` into `source` assuming they are siblings and indepndent.
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
static Operation *fuseSiblings(Operation *target, Operation *source,
RewriterBase &rewriter) {
auto targetOp = dyn_cast<scf::ForallOp>(target);
auto sourceOp = dyn_cast<scf::ForallOp>(source);
/// Check if `target` scf.for can be fused into `source` scf.for.
///
/// This simply checks if both loops have the same bounds and steps. No attempt
/// is made at checking that the side effects of `target` and `source` are
/// independent of each other.
static bool isForWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForOp>(target);
auto sourceOp = dyn_cast<scf::ForOp>(source);
if (!targetOp || !sourceOp)
return nullptr;
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
return false;

return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
targetOp.getStep() == sourceOp.getStep();
}

DiagnosedSilenceableFailure
transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
auto sourceOps = state.getPayloadOps(getSource());

Expand All @@ -510,13 +517,18 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
if (!diag.succeeded())
return diag;

// Check if the target can be fused into source.
if (!isForallWithIdenticalConfiguration(target, source)) {
Operation *fusedLoop;
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
if (isForWithIdenticalConfiguration(target, source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
} else if (isForallWithIdenticalConfiguration(target, source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
}

Operation *fusedLoop = fuseSiblings(target, source, rewriter);
assert(fusedLoop && "failed to fuse operations");

results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
Expand Down
99 changes: 68 additions & 31 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();

OperandRange targetOuts = target.getOutputs();
OperandRange sourceOuts = source.getOutputs();

// Create fused shared_outs.
SmallVector<Value> fusedOuts;
fusedOuts.reserve(numTargetOuts + numSourceOuts);
fusedOuts.append(targetOuts.begin(), targetOuts.end());
fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
llvm::append_range(fusedOuts, target.getOutputs());
llvm::append_range(fusedOuts, source.getOutputs());

// Create a new scf::forall op after the source loop.
// Create a new scf.forall op after the source loop.
rewriter.setInsertionPointAfter(source);
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
source.getMixedStep(), fusedOuts, source.getMapping());

// Map control operands.
IRMapping fusedMapping;
fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
IRMapping mapping;
mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());

// Map shared outs.
fusedMapping.map(target.getRegionIterArgs(),
fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
fusedMapping.map(
source.getRegionIterArgs(),
fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
mapping.map(target.getRegionIterArgs(),
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
mapping.map(source.getRegionIterArgs(),
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

// Append everything except the terminator into the fused operation.
rewriter.setInsertionPointToStart(fusedLoop.getBody());
for (Operation &op : target.getBody()->without_terminator())
rewriter.clone(op, fusedMapping);
rewriter.clone(op, mapping);
for (Operation &op : source.getBody()->without_terminator())
rewriter.clone(op, fusedMapping);
rewriter.clone(op, mapping);

// Fuse the old terminator in_parallel ops into the new one.
scf::InParallelOp targetTerm = target.getTerminator();
scf::InParallelOp sourceTerm = source.getTerminator();
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();

rewriter.setInsertionPointToStart(fusedTerm.getBody());
for (Operation &op : targetTerm.getYieldingOps())
rewriter.clone(op, fusedMapping);
rewriter.clone(op, mapping);
for (Operation &op : sourceTerm.getYieldingOps())
rewriter.clone(op, fusedMapping);

// Replace all uses of the old loops with the fused loop.
rewriter.replaceAllUsesWith(target.getResults(),
fusedLoop.getResults().slice(0, numTargetOuts));
rewriter.replaceAllUsesWith(
source.getResults(),
fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));

// Erase the old loops.
rewriter.eraseOp(target);
rewriter.eraseOp(source);
rewriter.clone(op, mapping);

// Replace old loops by substituting their uses by results of the fused loop.
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

return fusedLoop;
}

scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();

// Create fused init_args, with target's init_args before source's init_args.
SmallVector<Value> fusedInitArgs;
llvm::append_range(fusedInitArgs, target.getInitArgs());
llvm::append_range(fusedInitArgs, source.getInitArgs());

// Create a new scf.for op after the source loop (with scf.yield terminator
// (without arguments) only in case its init_args is empty).
rewriter.setInsertionPointAfter(source);
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
source.getStep(), fusedInitArgs);

// Map original induction variables and operands to those of the fused loop.
IRMapping mapping;
mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
mapping.map(target.getRegionIterArgs(),
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
mapping.map(source.getRegionIterArgs(),
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

// Merge target's body into the new (fused) for loop and then source's body.
rewriter.setInsertionPointToStart(fusedLoop.getBody());
for (Operation &op : target.getBody()->without_terminator())
rewriter.clone(op, mapping);
for (Operation &op : source.getBody()->without_terminator())
rewriter.clone(op, mapping);

// Build fused yield results by appropriately mapping original yield operands.
SmallVector<Value> yieldResults;
for (Value operand : target.getBody()->getTerminator()->getOperands())
yieldResults.push_back(mapping.lookupOrDefault(operand));
for (Value operand : source.getBody()->getTerminator()->getOperands())
yieldResults.push_back(mapping.lookupOrDefault(operand));
if (!yieldResults.empty())
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);

// Replace old loops by substituting their uses by results of the fused loop.
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

return fusedLoop;
}
Loading