Skip to content

Commit 5ebc4fb

Browse files
committed
review comments
1 parent 89ee43b commit 5ebc4fb

File tree

3 files changed

+20
-19
lines changed

3 files changed

+20
-19
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2236,7 +2236,7 @@ def HoistRedundantVectorBroadcastsOp :
22362236
let extraClassDeclaration = [{
22372237
::mlir::DiagnosedSilenceableFailure applyToOne(
22382238
::mlir::transform::TransformRewriter &rewriter,
2239-
::mlir::func::FuncOp target,
2239+
::mlir::Operation *target,
22402240
::mlir::transform::ApplyToEachResultList &results,
22412241
::mlir::transform::TransformState &state);
22422242
}];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3312,7 +3312,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
33123312

33133313
DiagnosedSilenceableFailure
33143314
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3315-
transform::TransformRewriter &rewriter, func::FuncOp target,
3315+
transform::TransformRewriter &rewriter, mlir::Operation *target,
33163316
transform::ApplyToEachResultList &results,
33173317
transform::TransformState &state) {
33183318
linalg::hoistRedundantVectorBroadcasts(target);

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
5959

6060
// Generate the new yield with the replaced operand
6161
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
62-
yieldOp->getOperand(index).replaceAllUsesWith(newYieldValue);
62+
rewriter.replaceAllUsesWith(yieldOp->getOperand(index), newYieldValue);
6363

6464
// Move the loop body to the new op.
6565
rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
@@ -74,19 +74,19 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
7474

7575
// Hoist out a pair of corresponding vector.extract+vector.broadcast
7676
// operations. This function transforms a loop like this:
77-
// %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
78-
// %e = vector.extract %iterarg : t1 to t2
79-
// %u = // do something with %e : t2
77+
// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
78+
// %e = vector.extract %iarg : t1 to t2
79+
// %u = "some_use"(%e) : (t2) -> t2
8080
// %b = vector.broadcast %u : t2 to t1
8181
// scf.yield %b : t1
8282
// }
8383
// into the following:
8484
// %e = vector.extract %v: t1 to t2
85-
// %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
86-
// %u' = // do something with %iterarg : t2
85+
// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
86+
// %u' = "some_use"(%iarg) : (t2) -> t2
8787
// scf.yield %u' : t2
8888
// }
89-
// %loop = vector.broadcast %loop' : t2 to t1
89+
// %res = vector.broadcast %res' : t2 to t1
9090
void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
9191
bool changed = true;
9292
while (changed) {
@@ -118,14 +118,12 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
118118
auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
119119

120120
// Check that the loop yields a broadcast
121-
auto lastOp = loop.getBody()->getTerminator();
122-
auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
123-
if (!yieldOp)
121+
auto yieldedVal =
122+
loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
123+
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
124+
if (!broadcast)
124125
return WalkResult::advance();
125126

126-
auto broadcast = dyn_cast<vector::BroadcastOp>(
127-
yieldOp->getOperand(index).getDefiningOp());
128-
129127
LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
130128

131129
Type broadcastInputType = broadcast.getSourceType();
@@ -138,18 +136,21 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
138136
if (!loop.isDefinedOutsideOfLoop(operand))
139137
return WalkResult::advance();
140138

139+
IRRewriter rewriter(extractOp.getContext());
140+
141141
extractOp.getVectorMutable().assign(initArg);
142142
loop.moveOutOfLoop(extractOp);
143-
broadcast->moveAfter(loop);
143+
rewriter.moveOpAfter(broadcast, loop);
144144

145-
IRRewriter rewriter(extractOp.getContext());
146145
auto newLoop = replaceWithDifferentYield(
147146
rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
148147

149148
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
150149

151-
newLoop.getResult(index).replaceAllUsesWith(broadcast);
152-
broadcast.getSourceMutable().assign(newLoop.getResult(index));
150+
rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
151+
rewriter.modifyOpInPlace(broadcast, [&] {
152+
broadcast.getSourceMutable().assign(newLoop.getResult(index));
153+
});
153154

154155
changed = true;
155156
return WalkResult::interrupt();

0 commit comments

Comments
 (0)