Skip to content

Commit bf4eb78

Browse files
committed
review comments
1 parent 854e4d1 commit bf4eb78

File tree

4 files changed

+49
-28
lines changed

4 files changed

+49
-28
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,16 @@ namespace linalg {
4343
/// when used on distributed loops with memref semantics!
4444
void hoistRedundantVectorTransfers(Operation *root);
4545

46-
void hoistRedundantVectorBroadcasts(Operation *root);
46+
/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
47+
/// scf::ForOp iteratively, if the following conditions are met:
48+
/// 1. The vector.extract operation is applied on an iter_argument, and no
49+
/// other operator is using this argument in the body of the loop.
50+
/// 2. The position of the vector.extract is either a static value, or defined
51+
/// outside of the loop.
52+
/// 3. The vector.broadcast operation is yielded by the loop.
53+
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
54+
/// function on the candidate loop above which to hoist.
55+
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root);
4756

4857
} // namespace linalg
4958
} // namespace mlir

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3315,7 +3315,8 @@ transform::HoistRedundantVectorBroadcastsOp::applyToOne(
33153315
transform::TransformRewriter &rewriter, mlir::Operation *target,
33163316
transform::ApplyToEachResultList &results,
33173317
transform::TransformState &state) {
3318-
linalg::hoistRedundantVectorBroadcasts(target);
3318+
rewriter.setInsertionPoint(target);
3319+
linalg::hoistRedundantVectorBroadcasts(rewriter, target);
33193320
results.push_back(target);
33203321
return DiagnosedSilenceableFailure::success();
33213322
}

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,39 @@ using llvm::dbgs;
4343
using namespace mlir;
4444
using namespace mlir::linalg;
4545

46-
scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
47-
Value newInitOperand, int index,
48-
Value newYieldValue) {
46+
/// Replace `loop` with a new loop that has a different init operand at
47+
/// position `index`. The body of this loop is moved over to the new loop.
48+
///
49+
/// `newInitOperands` specifies the replacement "init" operands.
50+
/// `newYieldValue` is the replacement yield value of the loop at position
51+
/// `index`.
52+
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53+
scf::ForOp loop,
54+
Value newInitOperand,
55+
unsigned index,
56+
Value newYieldValue) {
4957
OpBuilder::InsertionGuard g(rewriter);
5058
rewriter.setInsertionPoint(loop.getOperation());
5159
auto inits = llvm::to_vector(loop.getInits());
5260

53-
// Replace the init value with the new operand
61+
// Replace the init value with the new operand.
62+
assert(index < inits.size());
5463
inits[index] = newInitOperand;
5564

5665
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
5766
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
5867
inits, [](OpBuilder &, Location, Value, ValueRange) {});
5968

60-
// Generate the new yield with the replaced operand
69+
// Generate the new yield with the replaced operand.
6170
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
62-
rewriter.replaceAllUsesWith(yieldOp->getOperand(index), newYieldValue);
71+
yieldOp.setOperand(index, newYieldValue);
6372

6473
// Move the loop body to the new op.
6574
rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
66-
newLoop.getBody()->getArguments().take_front(
67-
loop.getBody()->getNumArguments()));
75+
newLoop.getBody()->getArguments());
6876

6977
// Replace the old loop.
70-
rewriter.replaceOp(loop.getOperation(),
71-
newLoop->getResults().take_front(loop.getNumResults()));
78+
rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
7279
return newLoop;
7380
}
7481

@@ -87,7 +94,8 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
8794
// scf.yield %u' : t2
8895
// }
8996
// %res = vector.broadcast %res' : t2 to t1
90-
void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
97+
void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
98+
Operation *root) {
9199
bool changed = true;
92100
while (changed) {
93101
changed = false;
@@ -104,24 +112,28 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
104112
if (!loop)
105113
return WalkResult::advance();
106114

107-
// Check that the vector to extract from is an iter_arg
115+
// Check that the vector to extract from is a BlockArgument.
108116
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
109117
if (!blockArg)
110118
return WalkResult::advance();
111119

120+
// Check that the blockArg is an iter_arg of the loop.
121+
OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122+
if (!initArg)
123+
return WalkResult::advance();
124+
112125
// If the iter_arg does not have only one use, it won't be possible to
113126
// hoist the extractOp out.
114127
if (!blockArg.hasOneUse())
115128
return WalkResult::advance();
116129

117-
auto initArg = loop.getTiedLoopInit(blockArg)->get();
118-
auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
130+
unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
119131

120-
// Check that the loop yields a broadcast
121-
auto yieldedVal =
132+
// Check that the loop yields a broadcast that has just one use.
133+
Operation *yieldedVal =
122134
loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
123135
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
124-
if (!broadcast)
136+
if (!broadcast || !broadcast.getResult().hasOneUse())
125137
return WalkResult::advance();
126138

127139
LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
@@ -131,26 +143,25 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
131143
return WalkResult::advance();
132144

133145
// The position of the extract must be defined outside of the loop if
134-
// it is dynamic
146+
// it is dynamic.
135147
for (auto operand : extractOp.getDynamicPosition())
136148
if (!loop.isDefinedOutsideOfLoop(operand))
137149
return WalkResult::advance();
138150

139-
IRRewriter rewriter(extractOp.getContext());
140-
141-
extractOp.getVectorMutable().assign(initArg);
151+
rewriter.modifyOpInPlace(broadcast, [&] {
152+
extractOp.getVectorMutable().assign(initArg->get());
153+
});
142154
loop.moveOutOfLoop(extractOp);
143155
rewriter.moveOpAfter(broadcast, loop);
144156

145-
auto newLoop = replaceWithDifferentYield(
157+
scf::ForOp newLoop = replaceWithDifferentYield(
146158
rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
147159

148160
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
149161

150162
rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
151-
rewriter.modifyOpInPlace(broadcast, [&] {
152-
broadcast.getSourceMutable().assign(newLoop.getResult(index));
153-
});
163+
rewriter.modifyOpInPlace(
164+
broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
154165

155166
changed = true;
156167
return WalkResult::interrupt();

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,4 +673,4 @@ module attributes {transform.with_named_sequence} {
673673
: (!transform.any_op) -> !transform.any_op
674674
transform.yield
675675
}
676-
}
676+
}

0 commit comments

Comments
 (0)