@@ -59,7 +59,7 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
59
59
60
60
// Generate the new yield with the replaced operand
61
61
auto yieldOp = cast<scf::YieldOp>(loop.getBody ()->getTerminator ());
62
- yieldOp->getOperand (index). replaceAllUsesWith ( newYieldValue);
62
+ rewriter. replaceAllUsesWith ( yieldOp->getOperand (index), newYieldValue);
63
63
64
64
// Move the loop body to the new op.
65
65
rewriter.mergeBlocks (loop.getBody (), newLoop.getBody (),
@@ -74,19 +74,19 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
74
74
75
75
// Hoist out a pair of corresponding vector.extract+vector.broadcast
76
76
// operations. This function transforms a loop like this:
77
- // %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
78
- // %e = vector.extract %iterarg : t1 to t2
79
- // %u = // do something with %e : t2
77
+ // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
78
+ // %e = vector.extract %iarg : t1 to t2
79
+ // %u = "some_use"(%e) : (t2) -> t2
80
80
// %b = vector.broadcast %u : t2 to t1
81
81
// scf.yield %b : t1
82
82
// }
83
83
// into the following:
84
84
// %e = vector.extract %v: t1 to t2
85
- // %loop ' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
86
- // %u' = // do something with %iterarg : t2
85
+ // %res ' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
86
+ // %u' = "some_use"(%iarg) : (t2) -> t2
87
87
// scf.yield %u' : t2
88
88
// }
89
- // %loop = vector.broadcast %loop ' : t2 to t1
89
+ // %res = vector.broadcast %res ' : t2 to t1
90
90
void mlir::linalg::hoistRedundantVectorBroadcasts (Operation *root) {
91
91
bool changed = true ;
92
92
while (changed) {
@@ -118,14 +118,12 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
118
118
auto index = blockArg.getArgNumber () - loop.getNumInductionVars ();
119
119
120
120
// Check that the loop yields a broadcast
121
- auto lastOp = loop.getBody ()->getTerminator ();
122
- auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
123
- if (!yieldOp)
121
+ auto yieldedVal =
122
+ loop.getTiedLoopYieldedValue (blockArg)->get ().getDefiningOp ();
123
+ auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
124
+ if (!broadcast)
124
125
return WalkResult::advance ();
125
126
126
- auto broadcast = dyn_cast<vector::BroadcastOp>(
127
- yieldOp->getOperand (index).getDefiningOp ());
128
-
129
127
LLVM_DEBUG (DBGS () << " Candidate broadcast: " << broadcast << " \n " );
130
128
131
129
Type broadcastInputType = broadcast.getSourceType ();
@@ -138,18 +136,21 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
138
136
if (!loop.isDefinedOutsideOfLoop (operand))
139
137
return WalkResult::advance ();
140
138
139
+ IRRewriter rewriter (extractOp.getContext ());
140
+
141
141
extractOp.getVectorMutable ().assign (initArg);
142
142
loop.moveOutOfLoop (extractOp);
143
- broadcast-> moveAfter ( loop);
143
+ rewriter. moveOpAfter (broadcast, loop);
144
144
145
- IRRewriter rewriter (extractOp.getContext ());
146
145
auto newLoop = replaceWithDifferentYield (
147
146
rewriter, loop, extractOp.getResult (), index, broadcast.getSource ());
148
147
149
148
LLVM_DEBUG (DBGS () << " New loop: " << newLoop << " \n " );
150
149
151
- newLoop.getResult (index).replaceAllUsesWith (broadcast);
152
- broadcast.getSourceMutable ().assign (newLoop.getResult (index));
150
+ rewriter.replaceAllUsesWith (newLoop.getResult (index), broadcast);
151
+ rewriter.modifyOpInPlace (broadcast, [&] {
152
+ broadcast.getSourceMutable ().assign (newLoop.getResult (index));
153
+ });
153
154
154
155
changed = true ;
155
156
return WalkResult::interrupt ();
0 commit comments