Skip to content

Commit 63086d6

Browse files
[mlir][Interfaces] LoopLikeOpInterface: Add replaceWithAdditionalYields (#67121)
`affine::replaceForOpWithNewYields` and `replaceLoopWithNewYields` (for "scf.for") are now interface methods and additional loop-carried variables can now be added to "scf.for"/"affine.for" uniformly. (No more `TypeSwitch` needed.) Note: `scf.while` and other loops with loop-carried variables can implement `replaceWithAdditionalYields`, but to keep this commit small, that is not done in this commit.
1 parent 19357b4 commit 63086d6

File tree

17 files changed

+241
-256
lines changed

17 files changed

+241
-256
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -490,20 +490,6 @@ void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
490490
function_ref<void(OpBuilder &, Location, ValueRange)>
491491
bodyBuilderFn = nullptr);
492492

493-
/// Replace `loop` with a new loop where `newIterOperands` are appended with
494-
/// new initialization values and `newYieldedValues` are added as new yielded
495-
/// values. The returned ForOp has `newYieldedValues.size()` new result values.
496-
/// Additionally, if `replaceLoopResults` is true, all uses of
497-
/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
498-
/// return values of the original loop respectively. The original loop is
499-
/// deleted and the new loop returned.
500-
/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
501-
AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
502-
ValueRange newIterOperands,
503-
ValueRange newYieldedValues,
504-
ValueRange newIterArgs,
505-
bool replaceLoopResults = true);
506-
507493
/// AffineBound represents a lower or upper bound in the for operation.
508494
/// This class does not own the underlying operands. Instead, it refers
509495
/// to the operands stored in the AffineForOp. Its life span should not exceed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def AffineForOp : Affine_Op<"for",
120120
[AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
121121
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
122122
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
123-
"getSingleUpperBound"]>,
123+
"getSingleUpperBound", "replaceWithAdditionalYields"]>,
124124
DeclareOpInterfaceMethods<RegionBranchOpInterface,
125125
["getEntrySuccessorOperands"]>]> {
126126
let summary = "for operation";

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
122122
def ForOp : SCF_Op<"for",
123123
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
124124
["getInits", "getSingleInductionVar", "getSingleLowerBound",
125-
"getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration"]>,
125+
"getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
126+
"replaceWithAdditionalYields"]>,
126127
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
127128
ConditionallySpeculatable,
128129
DeclareOpInterfaceMethods<RegionBranchOpInterface,

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,6 @@ class CallOp;
3434
class FuncOp;
3535
} // namespace func
3636

37-
/// Replace the `loop` with `newIterOperands` added as new initialization
38-
/// values. `newYieldValuesFn` is a callback that can be used to specify
39-
/// the additional values to be yielded by the loop. The number of
40-
/// values returned by the callback should match the number of new
41-
/// initialization values. This function
42-
/// - Moves (i.e. doesnt clone) operations from the `loop` to the newly created
43-
/// loop
44-
/// - Replaces the uses of `loop` with the new loop.
45-
/// - `loop` isnt erased, but is left in a "no-op" state where the body of the
46-
/// loop just yields the basic block arguments that correspond to the
47-
/// initialization values of a loop. The loop is dead after this method.
48-
/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
49-
/// `newIterOperands` within the generated new loop are replaced
50-
/// with the corresponding `BlockArgument` in the loop body.
51-
using NewYieldValueFn = std::function<SmallVector<Value>(
52-
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
53-
scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
54-
ValueRange newIterOperands,
55-
const NewYieldValueFn &newYieldValuesFn,
56-
bool replaceIterOperandsUsesInLoop = true);
57-
// Simpler API if the new yields are just a list of values that can be
58-
// determined ahead of time.
59-
inline scf::ForOp
60-
replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
61-
ValueRange newIterOperands, ValueRange newYields,
62-
bool replaceIterOperandsUsesInLoop = true) {
63-
auto fn = [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
64-
return SmallVector<Value>(newYields.begin(), newYields.end());
65-
};
66-
return replaceLoopWithNewYields(builder, loop, newIterOperands, fn,
67-
replaceIterOperandsUsesInLoop);
68-
}
69-
7037
/// Update a perfectly nested loop nest to yield new values from the innermost
7138
/// loop and propagating it up through the loop nest. This function
7239
/// - Expects `loopNest` to be a perfectly nested loop with outer most loop
@@ -82,11 +49,10 @@ replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
8249
/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
8350
/// `newIterOperands` within the generated new loop are replaced with the
8451
/// corresponding `BlockArgument` in the loop body.
85-
SmallVector<scf::ForOp>
86-
replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
87-
ValueRange newIterOperands,
88-
const NewYieldValueFn &newYieldValueFn,
89-
bool replaceIterOperandsUsesInLoop = true);
52+
SmallVector<scf::ForOp> replaceLoopNestWithNewYields(
53+
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
54+
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
55+
bool replaceIterOperandsUsesInLoop = true);
9056

9157
/// Outline a region with a single block into a new FuncOp.
9258
/// Assumes the FuncOp result types is the type of the yielded operands of the

mlir/include/mlir/Interfaces/LoopLikeInterface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818
namespace mlir {
1919
class RewriterBase;
20+
21+
/// A function that returns the additional yielded values during
22+
/// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
23+
/// iter_args. This function should return as many values as there are block
24+
/// arguments in `newBbArgs`.
25+
using NewYieldValuesFn = std::function<SmallVector<Value>(
26+
OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
2027
} // namespace mlir
2128

2229
/// Include the generated interface declarations.

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,31 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
141141
return ::mlir::Block::BlockArgListType();
142142
}]
143143
>,
144+
InterfaceMethod<[{
145+
Append the specified additional "init" operands: replace this loop with
146+
a new loop that has the additional init operands. The loop body of
147+
this loop is moved over to the new loop.
148+
149+
`newInitOperands` specifies the additional "init" operands.
150+
`newYieldValuesFn` is a function that returns the yielded values (which
151+
can be computed based on the additional region iter_args). If
152+
`replaceInitOperandUsesInLoop` is set, all uses of the additional init
153+
operands inside of this loop are replaced with the corresponding, newly
154+
added region iter_args.
155+
156+
Note: Loops that do not support init/iter_args should return "failure".
157+
}],
158+
/*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
159+
/*methodName=*/"replaceWithAdditionalYields",
160+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
161+
"::mlir::ValueRange":$newInitOperands,
162+
"bool":$replaceInitOperandUsesInLoop,
163+
"const ::mlir::NewYieldValuesFn &":$newYieldValuesFn),
164+
/*methodBody=*/"",
165+
/*defaultImplementation=*/[{
166+
return ::mlir::failure();
167+
}]
168+
>,
144169
];
145170

146171
let extraClassDeclaration = [{
@@ -149,6 +174,24 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
149174
/// because the control flow graph is cyclic
150175
static bool blockIsInLoop(Block *block);
151176
}];
177+
178+
let extraSharedClassDeclaration = [{
179+
/// Append the specified additional "init" operands: replace this loop with
180+
/// a new loop that has the additional init operands. The loop body of this
181+
/// loop is moved over to the new loop.
182+
///
183+
/// The newly added region iter_args are yielded from the loop.
184+
::mlir::FailureOr<::mlir::LoopLikeOpInterface>
185+
replaceWithAdditionalIterOperands(::mlir::RewriterBase &rewriter,
186+
::mlir::ValueRange newInitOperands,
187+
bool replaceInitOperandUsesInLoop) {
188+
return $_op.replaceWithAdditionalYields(
189+
rewriter, newInitOperands, replaceInitOperandUsesInLoop,
190+
[](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
191+
return SmallVector<Value>(newBBArgs);
192+
});
193+
}
194+
}];
152195
}
153196

154197
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,6 +2575,58 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
25752575
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
25762576
}
25772577

2578+
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2579+
RewriterBase &rewriter, ValueRange newInitOperands,
2580+
bool replaceInitOperandUsesInLoop,
2581+
const NewYieldValuesFn &newYieldValuesFn) {
2582+
// Create a new loop before the existing one, with the extra operands.
2583+
OpBuilder::InsertionGuard g(rewriter);
2584+
rewriter.setInsertionPoint(getOperation());
2585+
auto inits = llvm::to_vector(getInits());
2586+
inits.append(newInitOperands.begin(), newInitOperands.end());
2587+
AffineForOp newLoop = rewriter.create<AffineForOp>(
2588+
getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2589+
getUpperBoundOperands(), getUpperBoundMap(), getStep(), inits);
2590+
2591+
// Generate the new yield values and append them to the scf.yield operation.
2592+
auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2593+
ArrayRef<BlockArgument> newIterArgs =
2594+
newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2595+
{
2596+
OpBuilder::InsertionGuard g(rewriter);
2597+
rewriter.setInsertionPoint(yieldOp);
2598+
SmallVector<Value> newYieldedValues =
2599+
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2600+
assert(newInitOperands.size() == newYieldedValues.size() &&
2601+
"expected as many new yield values as new iter operands");
2602+
rewriter.updateRootInPlace(yieldOp, [&]() {
2603+
yieldOp.getOperandsMutable().append(newYieldedValues);
2604+
});
2605+
}
2606+
2607+
// Move the loop body to the new op.
2608+
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2609+
newLoop.getBody()->getArguments().take_front(
2610+
getBody()->getNumArguments()));
2611+
2612+
if (replaceInitOperandUsesInLoop) {
2613+
// Replace all uses of `newInitOperands` with the corresponding basic block
2614+
// arguments.
2615+
for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2616+
rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2617+
[&](OpOperand &use) {
2618+
Operation *user = use.getOwner();
2619+
return newLoop->isProperAncestor(user);
2620+
});
2621+
}
2622+
}
2623+
2624+
// Replace the old loop.
2625+
rewriter.replaceOp(getOperation(),
2626+
newLoop->getResults().take_front(getNumResults()));
2627+
return cast<LoopLikeOpInterface>(newLoop.getOperation());
2628+
}
2629+
25782630
Speculation::Speculatability AffineForOp::getSpeculatability() {
25792631
// `affine.for (I = Start; I < End; I += 1)` terminates for all values of
25802632
// Start and End.
@@ -2725,50 +2777,6 @@ void mlir::affine::buildAffineLoopNest(
27252777
buildAffineLoopFromValues);
27262778
}
27272779

2728-
AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
2729-
AffineForOp loop,
2730-
ValueRange newIterOperands,
2731-
ValueRange newYieldedValues,
2732-
ValueRange newIterArgs,
2733-
bool replaceLoopResults) {
2734-
assert(newIterOperands.size() == newYieldedValues.size() &&
2735-
"newIterOperands must be of the same size as newYieldedValues");
2736-
// Create a new loop before the existing one, with the extra operands.
2737-
OpBuilder::InsertionGuard g(b);
2738-
b.setInsertionPoint(loop);
2739-
auto operands = llvm::to_vector<4>(loop.getInits());
2740-
operands.append(newIterOperands.begin(), newIterOperands.end());
2741-
SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
2742-
SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
2743-
SmallVector<Value, 4> steps(loop.getStep());
2744-
auto lbMap = loop.getLowerBoundMap();
2745-
auto ubMap = loop.getUpperBoundMap();
2746-
AffineForOp newLoop =
2747-
b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
2748-
loop.getStep(), operands);
2749-
// Take the body of the original parent loop.
2750-
newLoop.getRegion().takeBody(loop.getRegion());
2751-
for (Value val : newIterArgs)
2752-
newLoop.getRegion().addArgument(val.getType(), val.getLoc());
2753-
2754-
// Update yield operation with new values to be added.
2755-
if (!newYieldedValues.empty()) {
2756-
auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
2757-
b.setInsertionPoint(yield);
2758-
auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
2759-
yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
2760-
b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
2761-
yield.erase();
2762-
}
2763-
if (replaceLoopResults) {
2764-
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
2765-
loop.getNumResults()))) {
2766-
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
2767-
}
2768-
}
2769-
return newLoop;
2770-
}
2771-
27722780
//===----------------------------------------------------------------------===//
27732781
// AffineIfOp
27742782
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Affine/LoopUtils.h"
2020
#include "mlir/IR/IRMapping.h"
2121
#include "mlir/IR/Operation.h"
22+
#include "mlir/IR/PatternMatch.h"
2223
#include "llvm/Support/Debug.h"
2324
#include "llvm/Support/raw_ostream.h"
2425
#include <optional>
@@ -361,16 +362,22 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
361362
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
362363
if (!tripCount || *tripCount != 1)
363364
return failure();
364-
auto iterOperands = forOp.getInits();
365365
auto *parentOp = forOp->getParentOp();
366366
if (!isa<AffineForOp>(parentOp))
367367
return failure();
368-
auto newOperands = forOp.getBody()->getTerminator()->getOperands();
369-
OpBuilder b(parentOp);
368+
SmallVector<Value> newOperands;
369+
llvm::append_range(newOperands,
370+
forOp.getBody()->getTerminator()->getOperands());
371+
IRRewriter rewriter(parentOp->getContext());
372+
int64_t parentOpNumResults = parentOp->getNumResults();
370373
// Replace the parent loop and add iteroperands and results from the `forOp`.
371374
AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
372-
AffineForOp newLoop = replaceForOpWithNewYields(
373-
b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());
375+
AffineForOp newLoop =
376+
cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
377+
rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
378+
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
379+
return newOperands;
380+
}));
374381

375382
// For sibling-fusion users, collect operations that use the results of the
376383
// `forOp` outside the new parent loop that has absorbed all its iter args
@@ -387,7 +394,7 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
387394
// Update the results of the `forOp` in the new loop.
388395
for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
389396
forOp.getResult(i).replaceAllUsesWith(
390-
newLoop.getResult(i + parentOp->getNumResults()));
397+
newLoop.getResult(i + parentOpNumResults));
391398
}
392399
// For sibling-fusion users, move operations that use the results of the
393400
// `forOp` outside the new parent loop
@@ -412,7 +419,6 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
412419
parentBlock->getOperations().splice(Block::iterator(forOp),
413420
forOp.getBody()->getOperations());
414421
forOp.erase();
415-
parentForOp.erase();
416422
return success();
417423
}
418424

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,9 +1197,9 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11971197
// `unrollJamFactor` copies of its iterOperands, iter_args and yield
11981198
// operands.
11991199
SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
1200-
OpBuilder builder(forOp.getContext());
1200+
IRRewriter rewriter(forOp.getContext());
12011201
for (AffineForOp oldForOp : loopsWithIterArgs) {
1202-
SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
1202+
SmallVector<Value> dupIterOperands, dupYieldOperands;
12031203
ValueRange oldIterOperands = oldForOp.getInits();
12041204
ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
12051205
ValueRange oldYieldOperands =
@@ -1208,19 +1208,21 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
12081208
// fix iterOperands and yield operands after cloning of sub-blocks.
12091209
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
12101210
dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
1211-
dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
12121211
dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
12131212
}
12141213
// Create a new loop with additional iterOperands, iter_args and yield
12151214
// operands. This new loop will take the loop body of the original loop.
1216-
AffineForOp newForOp = affine::replaceForOpWithNewYields(
1217-
builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
1215+
bool forOpReplaced = oldForOp == forOp;
1216+
AffineForOp newForOp =
1217+
cast<AffineForOp>(*oldForOp.replaceWithAdditionalYields(
1218+
rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
1219+
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
1220+
return dupYieldOperands;
1221+
}));
12181222
newLoopsWithIterArgs.push_back(newForOp);
12191223
// `forOp` has been replaced with a new loop.
1220-
if (oldForOp == forOp)
1224+
if (forOpReplaced)
12211225
forOp = newForOp;
1222-
assert(oldForOp.use_empty() && "old for op should not have any user");
1223-
oldForOp.erase();
12241226
// Update `operandMaps` for `newForOp` iterArgs and results.
12251227
ValueRange newIterArgs = newForOp.getRegionIterArgs();
12261228
unsigned oldNumIterArgs = oldIterArgs.size();
@@ -1294,7 +1296,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
12941296
// into one value. For example, for %0:2 = affine.for ... and addf, we add
12951297
// %1 = arith.addf %0#0, %0#1, and replace the following uses of %0#0 with
12961298
// %1.
1297-
builder.setInsertionPointAfter(forOp);
1299+
rewriter.setInsertionPointAfter(forOp);
12981300
auto loc = forOp.getLoc();
12991301
unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
13001302
for (LoopReduction &reduction : reductions) {
@@ -1305,7 +1307,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
13051307
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
13061308
rhs = forOp.getResult(i * oldNumResults + pos);
13071309
// Create ops based on reduction type.
1308-
lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
1310+
lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
13091311
if (!lhs)
13101312
return failure();
13111313
Operation *op = lhs.getDefiningOp();

0 commit comments

Comments
 (0)