Skip to content

[mlir][Interfaces] LoopLikeOpInterface: Add replaceWithAdditionalYields #67121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,20 +490,6 @@ void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
function_ref<void(OpBuilder &, Location, ValueRange)>
bodyBuilderFn = nullptr);

/// Replace `loop` with a new loop whose init operands are those of `loop`
/// followed by `newIterOperands`, and whose terminator additionally yields
/// `newYieldedValues`. The new loop is created right before `loop` and takes
/// over its body. For every value in `newIterArgs`, a block argument with the
/// same type and location is appended to the body region of the new loop.
/// The returned loop has `newYieldedValues.size()` additional result values.
///
/// If `replaceLoopResults` is true, all uses of `loop.getResults()` are
/// replaced with the first `loop.getNumResults()` results of the new loop,
/// respectively.
///
/// Note: `loop` itself is not erased by this function. Its body is moved into
/// the new loop and the caller is expected to erase the original op.
///
/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()` (checked
/// with an assertion).
AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
ValueRange newIterOperands,
ValueRange newYieldedValues,
ValueRange newIterArgs,
bool replaceLoopResults = true);

/// AffineBound represents a lower or upper bound in the for operation.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the AffineForOp. Its life span should not exceed
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def AffineForOp : Affine_Op<"for",
[AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
"getSingleUpperBound"]>,
"getSingleUpperBound", "replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
let summary = "for operation";
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInits", "getSingleInductionVar", "getSingleLowerBound",
"getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration"]>,
"getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
"replaceWithAdditionalYields"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
Expand Down
42 changes: 4 additions & 38 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,6 @@ class CallOp;
class FuncOp;
} // namespace func

/// Callback used by `replaceLoopWithNewYields` to specify the additional
/// values to be yielded by the loop. The number of values returned by the
/// callback should match the number of new initialization values.
/// NOTE(review): `newBBArgs` are presumably the block arguments that were
/// added for the new initialization values -- confirm against the definition.
using NewYieldValueFn = std::function<SmallVector<Value>(
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
/// Replace the `loop` with `newIterOperands` added as new initialization
/// values. `newYieldValuesFn` is a callback that can be used to specify
/// the additional values to be yielded by the loop. The number of
/// values returned by the callback should match the number of new
/// initialization values. This function
/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly
///   created loop.
/// - Replaces the uses of `loop` with the new loop.
/// - `loop` isn't erased, but is left in a "no-op" state where the body of the
///   loop just yields the basic block arguments that correspond to the
///   initialization values of a loop. The loop is dead after this method.
/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
///   `newIterOperands` within the generated new loop are replaced
///   with the corresponding `BlockArgument` in the loop body.
scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
ValueRange newIterOperands,
const NewYieldValueFn &newYieldValuesFn,
bool replaceIterOperandsUsesInLoop = true);
/// Convenience overload for callers that already know the additional values
/// to yield: `newYields` is handed back verbatim by the callback that is
/// forwarded to the callback-based `replaceLoopWithNewYields` overload.
inline scf::ForOp
replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                         ValueRange newIterOperands, ValueRange newYields,
                         bool replaceIterOperandsUsesInLoop = true) {
  // The callback ignores its arguments; the yielded values are fixed up front.
  const NewYieldValueFn yieldPrecomputedValues =
      [&newYields](OpBuilder & /*b*/, Location /*loc*/,
                   ArrayRef<BlockArgument> /*newBBArgs*/) {
        SmallVector<Value> precomputed;
        precomputed.append(newYields.begin(), newYields.end());
        return precomputed;
      };
  return replaceLoopWithNewYields(builder, loop, newIterOperands,
                                  yieldPrecomputedValues,
                                  replaceIterOperandsUsesInLoop);
}

/// Update a perfectly nested loop nest to yield new values from the innermost
/// loop and propagating it up through the loop nest. This function
/// - Expects `loopNest` to be a perfectly nested loop with outer most loop
Expand All @@ -82,11 +49,10 @@ replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
/// `newIterOperands` within the generated new loop are replaced with the
/// corresponding `BlockArgument` in the loop body.
SmallVector<scf::ForOp>
replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
ValueRange newIterOperands,
const NewYieldValueFn &newYieldValueFn,
bool replaceIterOperandsUsesInLoop = true);
SmallVector<scf::ForOp> replaceLoopNestWithNewYields(
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
bool replaceIterOperandsUsesInLoop = true);

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@

namespace mlir {
class RewriterBase;

/// A function that returns the additional yielded values during
/// `replaceWithAdditionalYields`.
///
/// - `b` / `loc`: builder and location handed to the callback by the loop
///   implementation (e.g. `AffineForOp` passes its rewriter, with the
///   insertion point set right before the loop terminator, and the location
///   of the loop).
/// - `newBbArgs`: the newly added region iter_args.
///
/// This function should return as many values as there are block arguments in
/// `newBbArgs`.
using NewYieldValuesFn = std::function<SmallVector<Value>(
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
} // namespace mlir

/// Include the generated interface declarations.
Expand Down
43 changes: 43 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return ::mlir::Block::BlockArgListType();
}]
>,
InterfaceMethod<[{
Append the specified additional "init" operands: replace this loop with
a new loop that has the additional init operands. The loop body of
this loop is moved over to the new loop.

`newInitOperands` specifies the additional "init" operands.
`newYieldValuesFn` is a function that returns the yielded values (which
can be computed based on the additional region iter_args). If
`replaceInitOperandUsesInLoop` is set, all uses of the additional init
operands inside of this loop are replaced with the corresponding, newly
added region iter_args.

Note: Loops that do not support init/iter_args should return "failure".
}],
/*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
/*methodName=*/"replaceWithAdditionalYields",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"::mlir::ValueRange":$newInitOperands,
"bool":$replaceInitOperandUsesInLoop,
"const ::mlir::NewYieldValuesFn &":$newYieldValuesFn),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
>,
];

let extraClassDeclaration = [{
Expand All @@ -149,6 +174,24 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// because the control flow graph is cyclic
static bool blockIsInLoop(Block *block);
}];

let extraSharedClassDeclaration = [{
/// Append the specified additional "init" operands: replace this loop with
/// a new loop that has the additional init operands. The loop body of this
/// loop is moved over to the new loop.
///
/// The newly added region iter_args are yielded from the loop.
///
/// This is a thin wrapper around `replaceWithAdditionalYields`: the callback
/// passed along simply returns the new block arguments as the additional
/// yielded values. The result of `replaceWithAdditionalYields` is returned
/// unchanged, so loops relying on the default interface implementation
/// yield "failure" here as well.
::mlir::FailureOr<::mlir::LoopLikeOpInterface>
replaceWithAdditionalIterOperands(::mlir::RewriterBase &rewriter,
::mlir::ValueRange newInitOperands,
bool replaceInitOperandUsesInLoop) {
return $_op.replaceWithAdditionalYields(
rewriter, newInitOperands, replaceInitOperandUsesInLoop,
[](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>(newBBArgs);
});
}
}];
}

#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
96 changes: 52 additions & 44 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2575,6 +2575,58 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
}

/// Replaces this loop with a new `affine.for` that carries `newInitOperands`
/// as additional init operands and yields the values produced by
/// `newYieldValuesFn` for them.
///
/// - `rewriter`: used for all IR modifications. Its insertion point is
///   restored when this function returns.
/// - `newInitOperands`: init operands appended after the existing ones.
/// - `replaceInitOperandUsesInLoop`: if set, uses of `newInitOperands` that
///   are nested inside the new loop are redirected to the matching new
///   region iter_args.
/// - `newYieldValuesFn`: invoked once, with the insertion point right before
///   the terminator of the old body and with the new region iter_args; it
///   must return exactly `newInitOperands.size()` values (asserted).
///
/// The old loop is replaced by the leading results of the new loop. This
/// implementation always returns the new loop and never returns failure.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStep(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation of the (still unmoved) old body.
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    // Separate name: do not shadow the function-level insertion guard.
    OpBuilder::InsertionGuard yieldGuard(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op. The old block arguments are mapped to
  // the leading block arguments of the new loop body.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for users nested inside the new loop.
    for (auto [initOperand, iterArg] :
         llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(initOperand, iterArg, [&](OpOperand &use) {
        return newLoop->isProperAncestor(use.getOwner());
      });
    }
  }

  // Replace the old loop. Only the leading results correspond to the results
  // of the old loop; the trailing ones belong to the new init operands.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}

Speculation::Speculatability AffineForOp::getSpeculatability() {
// `affine.for (I = Start; I < End; I += 1)` terminates for all values of
// Start and End.
Expand Down Expand Up @@ -2725,50 +2777,6 @@ void mlir::affine::buildAffineLoopNest(
buildAffineLoopFromValues);
}

/// Creates a new `affine.for` right before `loop` whose init operands are the
/// init operands of `loop` followed by `newIterOperands`, moves the body of
/// `loop` into it and appends `newYieldedValues` to the yielded values.
///
/// - `b`: builder used to create the new ops. Its insertion point is restored
///   on return.
/// - `newIterArgs`: for each value, a block argument of the same type and
///   location is appended to the body region of the new loop.
/// - `replaceLoopResults`: if true, every result of `loop` is replaced by the
///   result of the new loop at the same position.
///
/// `loop` is not erased here; it is left without its body and the caller is
/// responsible for erasing it. Returns the new loop.
AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
                                                    AffineForOp loop,
                                                    ValueRange newIterOperands,
                                                    ValueRange newYieldedValues,
                                                    ValueRange newIterArgs,
                                                    bool replaceLoopResults) {
  assert(newIterOperands.size() == newYieldedValues.size() &&
         "newIterOperands must be of the same size as newYieldedValues");
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getInits());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
  auto lbMap = loop.getLowerBoundMap();
  auto ubMap = loop.getUpperBoundMap();
  AffineForOp newLoop =
      b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
                            loop.getStep(), operands);
  // Take the body of the original parent loop. The taken body only has the
  // block arguments of the old loop, so the extra iter_args are added here.
  newLoop.getRegion().takeBody(loop.getRegion());
  for (Value val : newIterArgs)
    newLoop.getRegion().addArgument(val.getType(), val.getLoc());

  // Update yield operation with new values to be added.
  if (!newYieldedValues.empty()) {
    auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
    b.setInsertionPoint(yield);
    auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
    yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
    b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
    yield.erase();
  }
  if (replaceLoopResults) {
    // Only the leading results of the new loop correspond to the old results.
    for (auto [oldResult, newResult] :
         llvm::zip(loop.getResults(),
                   newLoop.getResults().take_front(loop.getNumResults()))) {
      oldResult.replaceAllUsesWith(newResult);
    }
  }
  return newLoop;
}

//===----------------------------------------------------------------------===//
// AffineIfOp
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 13 additions & 7 deletions mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
Expand Down Expand Up @@ -361,16 +362,22 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount || *tripCount != 1)
return failure();
auto iterOperands = forOp.getInits();
auto *parentOp = forOp->getParentOp();
if (!isa<AffineForOp>(parentOp))
return failure();
auto newOperands = forOp.getBody()->getTerminator()->getOperands();
OpBuilder b(parentOp);
SmallVector<Value> newOperands;
llvm::append_range(newOperands,
forOp.getBody()->getTerminator()->getOperands());
IRRewriter rewriter(parentOp->getContext());
int64_t parentOpNumResults = parentOp->getNumResults();
// Replace the parent loop and add iteroperands and results from the `forOp`.
AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
AffineForOp newLoop = replaceForOpWithNewYields(
b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());
AffineForOp newLoop =
cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
return newOperands;
}));

// For sibling-fusion users, collect operations that use the results of the
// `forOp` outside the new parent loop that has absorbed all its iter args
Expand All @@ -387,7 +394,7 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
// Update the results of the `forOp` in the new loop.
for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
forOp.getResult(i).replaceAllUsesWith(
newLoop.getResult(i + parentOp->getNumResults()));
newLoop.getResult(i + parentOpNumResults));
}
// For sibling-fusion users, move operations that use the results of the
// `forOp` outside the new parent loop
Expand All @@ -412,7 +419,6 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
parentBlock->getOperations().splice(Block::iterator(forOp),
forOp.getBody()->getOperations());
forOp.erase();
parentForOp.erase();
return success();
}

Expand Down
22 changes: 12 additions & 10 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,9 +1197,9 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// `unrollJamFactor` copies of its iterOperands, iter_args and yield
// operands.
SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
OpBuilder builder(forOp.getContext());
IRRewriter rewriter(forOp.getContext());
for (AffineForOp oldForOp : loopsWithIterArgs) {
SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
SmallVector<Value> dupIterOperands, dupYieldOperands;
ValueRange oldIterOperands = oldForOp.getInits();
ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
ValueRange oldYieldOperands =
Expand All @@ -1208,19 +1208,21 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// fix iterOperands and yield operands after cloning of sub-blocks.
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
}
// Create a new loop with additional iterOperands, iter_args and yield
// operands. This new loop will take the loop body of the original loop.
AffineForOp newForOp = affine::replaceForOpWithNewYields(
builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
bool forOpReplaced = oldForOp == forOp;
AffineForOp newForOp =
cast<AffineForOp>(*oldForOp.replaceWithAdditionalYields(
rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
return dupYieldOperands;
}));
newLoopsWithIterArgs.push_back(newForOp);
// `forOp` has been replaced with a new loop.
if (oldForOp == forOp)
if (forOpReplaced)
forOp = newForOp;
assert(oldForOp.use_empty() && "old for op should not have any user");
oldForOp.erase();
// Update `operandMaps` for `newForOp` iterArgs and results.
ValueRange newIterArgs = newForOp.getRegionIterArgs();
unsigned oldNumIterArgs = oldIterArgs.size();
Expand Down Expand Up @@ -1294,7 +1296,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// into one value. For example, for %0:2 = affine.for ... and addf, we add
// %1 = arith.addf %0#0, %0#1, and replace the following uses of %0#0 with
// %1.
builder.setInsertionPointAfter(forOp);
rewriter.setInsertionPointAfter(forOp);
auto loc = forOp.getLoc();
unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
for (LoopReduction &reduction : reductions) {
Expand All @@ -1305,7 +1307,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
rhs = forOp.getResult(i * oldNumResults + pos);
// Create ops based on reduction type.
lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
if (!lhs)
return failure();
Operation *op = lhs.getDefiningOp();
Expand Down
Loading