@@ -43,32 +43,39 @@ using llvm::dbgs;
43
43
using namespace mlir ;
44
44
using namespace mlir ::linalg;
45
45
46
- scf::ForOp replaceWithDifferentYield (RewriterBase &rewriter, scf::ForOp loop,
47
- Value newInitOperand, int index,
48
- Value newYieldValue) {
46
+ // / Replace `loop` with a new loop that has a different init operand at
47
+ // / position `index`. The body of this loop is moved over to the new loop.
48
+ // /
49
+ // / `newInitOperands` specifies the replacement "init" operands.
50
+ // / `newYieldValue` is the replacement yield value of the loop at position
51
+ // / `index`.
52
+ static scf::ForOp replaceWithDifferentYield (RewriterBase &rewriter,
53
+ scf::ForOp loop,
54
+ Value newInitOperand,
55
+ unsigned index,
56
+ Value newYieldValue) {
49
57
OpBuilder::InsertionGuard g (rewriter);
50
58
rewriter.setInsertionPoint (loop.getOperation ());
51
59
auto inits = llvm::to_vector (loop.getInits ());
52
60
53
- // Replace the init value with the new operand
61
+ // Replace the init value with the new operand.
62
+ assert (index < inits.size ());
54
63
inits[index] = newInitOperand;
55
64
56
65
scf::ForOp newLoop = rewriter.create <scf::ForOp>(
57
66
loop.getLoc (), loop.getLowerBound (), loop.getUpperBound (), loop.getStep (),
58
67
inits, [](OpBuilder &, Location, Value, ValueRange) {});
59
68
60
- // Generate the new yield with the replaced operand
69
+ // Generate the new yield with the replaced operand.
61
70
auto yieldOp = cast<scf::YieldOp>(loop.getBody ()->getTerminator ());
62
- rewriter. replaceAllUsesWith (yieldOp-> getOperand ( index) , newYieldValue);
71
+ yieldOp. setOperand ( index, newYieldValue);
63
72
64
73
// Move the loop body to the new op.
65
74
rewriter.mergeBlocks (loop.getBody (), newLoop.getBody (),
66
- newLoop.getBody ()->getArguments ().take_front (
67
- loop.getBody ()->getNumArguments ()));
75
+ newLoop.getBody ()->getArguments ());
68
76
69
77
// Replace the old loop.
70
- rewriter.replaceOp (loop.getOperation (),
71
- newLoop->getResults ().take_front (loop.getNumResults ()));
78
+ rewriter.replaceOp (loop.getOperation (), newLoop->getResults ());
72
79
return newLoop;
73
80
}
74
81
@@ -87,7 +94,8 @@ scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
87
94
// scf.yield %u' : t2
88
95
// }
89
96
// %res = vector.broadcast %res' : t2 to t1
90
- void mlir::linalg::hoistRedundantVectorBroadcasts (Operation *root) {
97
+ void mlir::linalg::hoistRedundantVectorBroadcasts (RewriterBase &rewriter,
98
+ Operation *root) {
91
99
bool changed = true ;
92
100
while (changed) {
93
101
changed = false ;
@@ -104,24 +112,28 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
104
112
if (!loop)
105
113
return WalkResult::advance ();
106
114
107
- // Check that the vector to extract from is an iter_arg
115
+ // Check that the vector to extract from is a BlockArgument.
108
116
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector ());
109
117
if (!blockArg)
110
118
return WalkResult::advance ();
111
119
120
+ // Check that the blockArg is an iter_arg of the loop.
121
+ OpOperand *initArg = loop.getTiedLoopInit (blockArg);
122
+ if (!initArg)
123
+ return WalkResult::advance ();
124
+
112
125
// If the iter_arg does not have only one use, it won't be possible to
113
126
// hoist the extractOp out.
114
127
if (!blockArg.hasOneUse ())
115
128
return WalkResult::advance ();
116
129
117
- auto initArg = loop.getTiedLoopInit (blockArg)->get ();
118
- auto index = blockArg.getArgNumber () - loop.getNumInductionVars ();
130
+ unsigned index = blockArg.getArgNumber () - loop.getNumInductionVars ();
119
131
120
- // Check that the loop yields a broadcast
121
- auto yieldedVal =
132
+ // Check that the loop yields a broadcast that has just one use.
133
+ Operation * yieldedVal =
122
134
loop.getTiedLoopYieldedValue (blockArg)->get ().getDefiningOp ();
123
135
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
124
- if (!broadcast)
136
+ if (!broadcast || !broadcast. getResult (). hasOneUse () )
125
137
return WalkResult::advance ();
126
138
127
139
LLVM_DEBUG (DBGS () << " Candidate broadcast: " << broadcast << " \n " );
@@ -131,26 +143,25 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
131
143
return WalkResult::advance ();
132
144
133
145
// The position of the extract must be defined outside of the loop if
134
- // it is dynamic
146
+ // it is dynamic.
135
147
for (auto operand : extractOp.getDynamicPosition ())
136
148
if (!loop.isDefinedOutsideOfLoop (operand))
137
149
return WalkResult::advance ();
138
150
139
- IRRewriter rewriter (extractOp. getContext ());
140
-
141
- extractOp. getVectorMutable (). assign (initArg );
151
+ rewriter. modifyOpInPlace (broadcast, [&] {
152
+ extractOp. getVectorMutable (). assign (initArg-> get ());
153
+ } );
142
154
loop.moveOutOfLoop (extractOp);
143
155
rewriter.moveOpAfter (broadcast, loop);
144
156
145
- auto newLoop = replaceWithDifferentYield (
157
+ scf::ForOp newLoop = replaceWithDifferentYield (
146
158
rewriter, loop, extractOp.getResult (), index, broadcast.getSource ());
147
159
148
160
LLVM_DEBUG (DBGS () << " New loop: " << newLoop << " \n " );
149
161
150
162
rewriter.replaceAllUsesWith (newLoop.getResult (index), broadcast);
151
- rewriter.modifyOpInPlace (broadcast, [&] {
152
- broadcast.getSourceMutable ().assign (newLoop.getResult (index));
153
- });
163
+ rewriter.modifyOpInPlace (
164
+ broadcast, [&] { broadcast.setOperand (newLoop.getResult (index)); });
154
165
155
166
changed = true ;
156
167
return WalkResult::interrupt ();
0 commit comments