Skip to content

Commit 3835681

Browse files
committed
review comments
1 parent 854e4d1 commit 3835681

File tree

4 files changed

+50
-27
lines changed

4 files changed

+50
-27
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,16 @@ namespace linalg {
4343
/// when used on distributed loops with memref semantics!
4444
void hoistRedundantVectorTransfers(Operation *root);
4545

46-
void hoistRedundantVectorBroadcasts(Operation *root);
46+
/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
47+
/// scf::ForOp iteratively, if the following conditions are met:
48+
/// 1. The vector.extract operation is applied on an iter_argument, and no
49+
/// other operator is using this argument in the body of the loop.
50+
/// 2. The position of the vector.extract is either a static value, or defined
51+
/// outside of the loop.
52+
/// 3. The vector.broadcast operation is yielded by the loop.
53+
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
54+
/// function on the candidate loop above which to hoist.
55+
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root);
4756

4857
} // namespace linalg
4958
} // namespace mlir

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3315,7 +3315,8 @@ transform::HoistRedundantVectorBroadcastsOp::applyToOne(
33153315
transform::TransformRewriter &rewriter, mlir::Operation *target,
33163316
transform::ApplyToEachResultList &results,
33173317
transform::TransformState &state) {
3318-
linalg::hoistRedundantVectorBroadcasts(target);
3318+
rewriter.setInsertionPoint(target);
3319+
linalg::hoistRedundantVectorBroadcasts(rewriter, target);
33193320
results.push_back(target);
33203321
return DiagnosedSilenceableFailure::success();
33213322
}

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,40 @@ using llvm::dbgs;
4343
using namespace mlir;
4444
using namespace mlir::linalg;
4545

46-
scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
47-
Value newInitOperand, int index,
48-
Value newYieldValue) {
46+
/// Replace `loop` with a new loop that has a different init operand at
47+
/// position `index`. The body of this loop is moved over to the new loop.
48+
///
49+
/// `newInitOperands` specifies the replacement "init" operands.
50+
/// `newYieldValue` is the replacement yield value of the loop at position
51+
/// `index`.
52+
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53+
scf::ForOp loop,
54+
Value newInitOperand,
55+
unsigned index,
56+
Value newYieldValue) {
4957
OpBuilder::InsertionGuard g(rewriter);
5058
rewriter.setInsertionPoint(loop.getOperation());
5159
auto inits = llvm::to_vector(loop.getInits());
5260

53-
// Replace the init value with the new operand
61+
// Replace the init value with the new operand.
62+
assert(index < inits.size());
5463
inits[index] = newInitOperand;
5564

5665
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
5766
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
5867
inits, [](OpBuilder &, Location, Value, ValueRange) {});
5968

60-
// Generate the new yield with the replaced operand
69+
// Generate the new yield with the replaced operand.
6170
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71+
assert(!loop.isDefinedOutsideOfLoop(yieldOp->getOperand(index)));
6272
rewriter.replaceAllUsesWith(yieldOp->getOperand(index), newYieldValue);
6373

6474
// Move the loop body to the new op.
6575
rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
66-
newLoop.getBody()->getArguments().take_front(
67-
loop.getBody()->getNumArguments()));
76+
newLoop.getBody()->getArguments());
6877

6978
// Replace the old loop.
70-
rewriter.replaceOp(loop.getOperation(),
71-
newLoop->getResults().take_front(loop.getNumResults()));
79+
rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
7280
return newLoop;
7381
}
7482

@@ -87,7 +95,8 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
8795
// scf.yield %u' : t2
8896
// }
8997
// %res = vector.broadcast %res' : t2 to t1
90-
void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
98+
void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
99+
Operation *root) {
91100
bool changed = true;
92101
while (changed) {
93102
changed = false;
@@ -104,21 +113,26 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
104113
if (!loop)
105114
return WalkResult::advance();
106115

107-
// Check that the vector to extract from is an iter_arg
116+
// Check that the vector to extract from is a BlockArgument.
108117
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
109-
if (!blockArg)
118+
if (!blockArg) {
119+
return WalkResult::advance();
120+
}
121+
122+
// Check that the blockArg is an iter_arg of the loop.
123+
OpOperand *initArg = loop.getTiedLoopInit(blockArg);
124+
if (!initArg)
110125
return WalkResult::advance();
111126

112127
// If the iter_arg does not have only one use, it won't be possible to
113128
// hoist the extractOp out.
114129
if (!blockArg.hasOneUse())
115130
return WalkResult::advance();
116131

117-
auto initArg = loop.getTiedLoopInit(blockArg)->get();
118-
auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
132+
unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
119133

120-
// Check that the loop yields a broadcast
121-
auto yieldedVal =
134+
// Check that the loop yields a broadcast.
135+
Operation *yieldedVal =
122136
loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
123137
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
124138
if (!broadcast)
@@ -131,26 +145,25 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
131145
return WalkResult::advance();
132146

133147
// The position of the extract must be defined outside of the loop if
134-
// it is dynamic
148+
// it is dynamic.
135149
for (auto operand : extractOp.getDynamicPosition())
136150
if (!loop.isDefinedOutsideOfLoop(operand))
137151
return WalkResult::advance();
138152

139-
IRRewriter rewriter(extractOp.getContext());
140-
141-
extractOp.getVectorMutable().assign(initArg);
153+
rewriter.modifyOpInPlace(broadcast, [&] {
154+
extractOp.getVectorMutable().assign(initArg->get());
155+
});
142156
loop.moveOutOfLoop(extractOp);
143157
rewriter.moveOpAfter(broadcast, loop);
144158

145-
auto newLoop = replaceWithDifferentYield(
159+
scf::ForOp newLoop = replaceWithDifferentYield(
146160
rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
147161

148162
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
149163

150164
rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
151-
rewriter.modifyOpInPlace(broadcast, [&] {
152-
broadcast.getSourceMutable().assign(newLoop.getResult(index));
153-
});
165+
rewriter.modifyOpInPlace(
166+
broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
154167

155168
changed = true;
156169
return WalkResult::interrupt();

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,4 +673,4 @@ module attributes {transform.with_named_sequence} {
673673
: (!transform.any_op) -> !transform.any_op
674674
transform.yield
675675
}
676-
}
676+
}

0 commit comments

Comments
 (0)