@@ -43,32 +43,40 @@ using llvm::dbgs;
43
43
using namespace mlir ;
44
44
using namespace mlir ::linalg;
45
45
46
- scf::ForOp replaceWithDifferentYield (RewriterBase &rewriter, scf::ForOp loop,
47
- Value newInitOperand, int index,
48
- Value newYieldValue) {
46
+ // / Replace `loop` with a new loop that has a different init operand at
47
+ // / position `index`. The body of this loop is moved over to the new loop.
48
+ // /
49
+ // / `newInitOperands` specifies the replacement "init" operands.
50
+ // / `newYieldValue` is the replacement yield value of the loop at position
51
+ // / `index`.
52
+ static scf::ForOp replaceWithDifferentYield (RewriterBase &rewriter,
53
+ scf::ForOp loop,
54
+ Value newInitOperand,
55
+ unsigned index,
56
+ Value newYieldValue) {
49
57
OpBuilder::InsertionGuard g (rewriter);
50
58
rewriter.setInsertionPoint (loop.getOperation ());
51
59
auto inits = llvm::to_vector (loop.getInits ());
52
60
53
- // Replace the init value with the new operand
61
+ // Replace the init value with the new operand.
62
+ assert (index < inits.size ());
54
63
inits[index] = newInitOperand;
55
64
56
65
scf::ForOp newLoop = rewriter.create <scf::ForOp>(
57
66
loop.getLoc (), loop.getLowerBound (), loop.getUpperBound (), loop.getStep (),
58
67
inits, [](OpBuilder &, Location, Value, ValueRange) {});
59
68
60
- // Generate the new yield with the replaced operand
69
+ // Generate the new yield with the replaced operand.
61
70
auto yieldOp = cast<scf::YieldOp>(loop.getBody ()->getTerminator ());
71
+ assert (!loop.isDefinedOutsideOfLoop (yieldOp->getOperand (index)));
62
72
rewriter.replaceAllUsesWith (yieldOp->getOperand (index), newYieldValue);
63
73
64
74
// Move the loop body to the new op.
65
75
rewriter.mergeBlocks (loop.getBody (), newLoop.getBody (),
66
- newLoop.getBody ()->getArguments ().take_front (
67
- loop.getBody ()->getNumArguments ()));
76
+ newLoop.getBody ()->getArguments ());
68
77
69
78
// Replace the old loop.
70
- rewriter.replaceOp (loop.getOperation (),
71
- newLoop->getResults ().take_front (loop.getNumResults ()));
79
+ rewriter.replaceOp (loop.getOperation (), newLoop->getResults ());
72
80
return newLoop;
73
81
}
74
82
@@ -87,7 +95,8 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
87
95
// scf.yield %u' : t2
88
96
// }
89
97
// %res = vector.broadcast %res' : t2 to t1
90
- void mlir::linalg::hoistRedundantVectorBroadcasts (Operation *root) {
98
+ void mlir::linalg::hoistRedundantVectorBroadcasts (RewriterBase &rewriter,
99
+ Operation *root) {
91
100
bool changed = true ;
92
101
while (changed) {
93
102
changed = false ;
@@ -104,21 +113,26 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
104
113
if (!loop)
105
114
return WalkResult::advance ();
106
115
107
- // Check that the vector to extract from is an iter_arg
116
+ // Check that the vector to extract from is a BlockArgument.
108
117
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector ());
109
- if (!blockArg)
118
+ if (!blockArg) {
119
+ return WalkResult::advance ();
120
+ }
121
+
122
+ // Check that the blockArg is an iter_arg of the loop.
123
+ OpOperand *initArg = loop.getTiedLoopInit (blockArg);
124
+ if (!initArg)
110
125
return WalkResult::advance ();
111
126
112
127
// If the iter_arg does not have only one use, it won't be possible to
113
128
// hoist the extractOp out.
114
129
if (!blockArg.hasOneUse ())
115
130
return WalkResult::advance ();
116
131
117
- auto initArg = loop.getTiedLoopInit (blockArg)->get ();
118
- auto index = blockArg.getArgNumber () - loop.getNumInductionVars ();
132
+ unsigned index = blockArg.getArgNumber () - loop.getNumInductionVars ();
119
133
120
- // Check that the loop yields a broadcast
121
- auto yieldedVal =
134
+ // Check that the loop yields a broadcast.
135
+ Operation * yieldedVal =
122
136
loop.getTiedLoopYieldedValue (blockArg)->get ().getDefiningOp ();
123
137
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
124
138
if (!broadcast)
@@ -131,26 +145,25 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
131
145
return WalkResult::advance ();
132
146
133
147
// The position of the extract must be defined outside of the loop if
134
- // it is dynamic
148
+ // it is dynamic.
135
149
for (auto operand : extractOp.getDynamicPosition ())
136
150
if (!loop.isDefinedOutsideOfLoop (operand))
137
151
return WalkResult::advance ();
138
152
139
- IRRewriter rewriter (extractOp. getContext ());
140
-
141
- extractOp. getVectorMutable (). assign (initArg );
153
+ rewriter. modifyOpInPlace (broadcast, [&] {
154
+ extractOp. getVectorMutable (). assign (initArg-> get ());
155
+ } );
142
156
loop.moveOutOfLoop (extractOp);
143
157
rewriter.moveOpAfter (broadcast, loop);
144
158
145
- auto newLoop = replaceWithDifferentYield (
159
+ scf::ForOp newLoop = replaceWithDifferentYield (
146
160
rewriter, loop, extractOp.getResult (), index, broadcast.getSource ());
147
161
148
162
LLVM_DEBUG (DBGS () << " New loop: " << newLoop << " \n " );
149
163
150
164
rewriter.replaceAllUsesWith (newLoop.getResult (index), broadcast);
151
- rewriter.modifyOpInPlace (broadcast, [&] {
152
- broadcast.getSourceMutable ().assign (newLoop.getResult (index));
153
- });
165
+ rewriter.modifyOpInPlace (
166
+ broadcast, [&] { broadcast.setOperand (newLoop.getResult (index)); });
154
167
155
168
changed = true ;
156
169
return WalkResult::interrupt ();
0 commit comments