@@ -120,8 +120,10 @@ namespace {
120
120
class MatchingSubsets {
121
121
public:
122
122
// / Insert a subset op.
123
- void insert (SubsetOpInterface op) {
123
+ void insert (SubsetOpInterface op, bool collectHoistableOps = true ) {
124
124
allSubsetOps.push_back (op);
125
+ if (!collectHoistableOps)
126
+ return ;
125
127
if (auto extractionOp =
126
128
dyn_cast<SubsetExtractionOpInterface>(op.getOperation ()))
127
129
insertExtractionOp (extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
148
150
});
149
151
}
150
152
153
+ // / Populate subset ops starting from the given region iter_arg. Return
154
+ // / "failure" if non-subset ops are found along the path to the loop yielding
155
+ // / op or if there is no single path to the tied yielded operand. If
156
+ // / `collectHoistableOps` is set to "false", subset ops are gathered
157
+ // / throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158
+ LogicalResult populateSubsetOpsAtIterArg (LoopLikeOpInterface loopLike,
159
+ BlockArgument iterArg,
160
+ bool collectHoistableOps = true );
161
+
151
162
private:
152
163
// / Helper function for equivalence of tensor values. Since only insertion
153
164
// / subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
225
236
return nullptr ;
226
237
}
227
238
228
- // / Hoist all subset ops that operate on the idx-th region iter_arg of the given
229
- // / loop-like op and index into loop-invariant subset locations. Return the
230
- // / newly created loop op (that has extra iter_args) or the original loop op if
231
- // / nothing was hoisted.
232
- static LoopLikeOpInterface hoistSubsetAtIterArg (LoopLikeOpInterface loopLike,
233
- BlockArgument iterArg) {
234
- IRRewriter rewriter (loopLike.getContext ());
239
+ LogicalResult
240
+ MatchingSubsets::populateSubsetOpsAtIterArg (LoopLikeOpInterface loopLike,
241
+ BlockArgument iterArg,
242
+ bool collectHoistableOps) {
235
243
assert (iterArg.getOwner ()->getParentOp () == loopLike && " invalid iter_arg" );
236
- auto it = llvm::find (loopLike.getRegionIterArgs (), iterArg);
237
- int64_t iterArgIdx = std::distance (loopLike.getRegionIterArgs ().begin (), it);
238
244
Value value = iterArg;
239
- MatchingSubsets subsets;
240
245
241
246
// Traverse use-def chain. Subset ops can be hoisted only if all ops along the
242
247
// use-def chain starting from the region iter_arg are subset extraction or
@@ -249,36 +254,71 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
249
254
Value nextValue = {};
250
255
251
256
for (OpOperand &use : value.getUses ()) {
257
+ if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner ())) {
258
+ // Subset ops in nested loops are collected to check if there are only
259
+ // disjoint subset ops, but such subset ops are not subject to hoisting.
260
+ // To hoist subset ops from nested loops, the hoisting transformation
261
+ // should be run on the nested loop.
262
+ auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg (&use);
263
+ if (!nestedIterArg)
264
+ return failure ();
265
+ // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266
+ // use-def chain starting at `nestedIterArg` and terminating in the
267
+ // tied, yielding operand.
268
+ if (failed (populateSubsetOpsAtIterArg (nestedLoop, nestedIterArg,
269
+ /* collectHoistableOps=*/ false )))
270
+ return failure ();
271
+ nextValue = nestedLoop.getTiedLoopResult (&use);
272
+ continue ;
273
+ }
274
+
252
275
auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner ());
253
276
if (!subsetOp)
254
- return loopLike ;
255
- subsets. insert (subsetOp);
277
+ return failure () ;
278
+ insert (subsetOp);
256
279
257
280
if (auto insertionOp =
258
281
dyn_cast<SubsetInsertionOpInterface>(use.getOwner ())) {
259
282
// The value must be used as a destination. (In case of a source, the
260
283
// entire tensor would be read, which would prevent any hoisting.)
261
284
if (&use != &insertionOp.getDestinationOperand ())
262
- return loopLike ;
285
+ return failure () ;
263
286
// There must be a single use-def chain from the region iter_arg to the
264
287
// terminator. I.e., only one insertion op. Branches are not supported.
265
288
if (nextValue)
266
- return loopLike ;
289
+ return failure () ;
267
290
nextValue = insertionOp.getUpdatedDestination ();
268
291
}
269
292
}
270
293
271
294
// Nothing can be hoisted if the chain does not continue with loop yielding
272
295
// op or a subset insertion op.
273
296
if (!nextValue)
274
- return loopLike ;
297
+ return failure () ;
275
298
value = nextValue;
276
299
}
277
300
278
301
// Hoist only if the SSA use-def chain ends in the yielding terminator of the
279
302
// loop and the yielded value is the `idx`-th operand. (I.e., there is no
280
303
// swapping yield.)
281
304
if (loopLike.getTiedLoopYieldedValue (iterArg) != yieldedOperand)
305
+ return failure ();
306
+
307
+ return success ();
308
+ }
309
+
310
+ // / Hoist all subset ops that operate on the idx-th region iter_arg of the given
311
+ // / loop-like op and index into loop-invariant subset locations. Return the
312
+ // / newly created loop op (that has extra iter_args) or the original loop op if
313
+ // / nothing was hoisted.
314
+ static LoopLikeOpInterface hoistSubsetAtIterArg (LoopLikeOpInterface loopLike,
315
+ BlockArgument iterArg) {
316
+ assert (iterArg.getOwner ()->getParentOp () == loopLike && " invalid iter_arg" );
317
+ auto it = llvm::find (loopLike.getRegionIterArgs (), iterArg);
318
+ int64_t iterArgIdx = std::distance (loopLike.getRegionIterArgs ().begin (), it);
319
+ IRRewriter rewriter (loopLike.getContext ());
320
+ MatchingSubsets subsets;
321
+ if (failed (subsets.populateSubsetOpsAtIterArg (loopLike, iterArg)))
282
322
return loopLike;
283
323
284
324
// Hoist all matching extraction-insertion pairs one-by-one.
0 commit comments