@@ -278,6 +278,21 @@ struct ErasedOpListener : public RewriterBase::Listener {
278
278
bool isErased (Operation *op) { return erased.count (op); }
279
279
};
280
280
281
+ // / Check if it is the ForOp that yield the result of inner loop
282
+ static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
283
+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
284
+ for (auto &&[index, op] :
285
+ llvm::enumerate (forOp.getBody ()->getOperations ())) {
286
+ // If the orderIndex of inner loop is the last second one before the
287
+ // yieldOp of ForOp, the given loop must yield the result of inner loop.
288
+ if (isa<LoopLikeOpInterface>(op)) {
289
+ return success ((index + 2 ) == forOp.getBody ()->getOperations ().size ());
290
+ }
291
+ }
292
+ }
293
+ return failure ();
294
+ }
295
+
281
296
/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
282
297
/// multi-level `extractSliceOp`. E.g.
283
298
///
@@ -293,7 +308,9 @@ std::optional<scf::SCFFuseProducerOfSliceResult>
293
308
mlir::scfX::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
294
309
Operation *candidateSliceOp) {
295
310
SmallVector<tensor::ExtractSliceOp> backwardSlice;
296
- if (failed (getRealProducerOfExtractSliceOp (candidateSliceOp, backwardSlice)))
311
+ FailureOr<OpResult> realProducer =
312
+ getRealProducerOfExtractSliceOp (candidateSliceOp, backwardSlice);
313
+ if (failed (realProducer))
297
314
return std::nullopt;
298
315
299
316
std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
@@ -303,14 +320,18 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
303
320
for (auto &&[index, sliceOp] : llvm::enumerate (backwardSlice)) {
304
321
// get nest loops between next candidate sliceOp and tiled producer.
305
322
auto whileProducerOutOfLoopBlock =
306
- [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
307
- if (fuseProducerResult) {
308
- Block &body = loop->getRegion (0 ).front ();
309
- if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
310
- ->getBlock () == &body)
311
- return failure ();
312
- }
313
- return success ();
323
+ [&fuseProducerResult,
324
+ &realProducer](LoopLikeOpInterface loop) -> LogicalResult {
325
+ // ensure that all surrounding outer loops are just yielding the result of
326
+ // the inner loops.
327
+ if (failed (isForOpYieldResultOfInnerLoop (loop)))
328
+ return failure ();
329
+ Operation *originalOp =
330
+ fuseProducerResult
331
+ ? fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
332
+ : realProducer->getDefiningOp ();
333
+ Block &body = loop->getRegion (0 ).front ();
334
+ return success (originalOp->getBlock () != &body);
314
335
};
315
336
SmallVector<LoopLikeOpInterface> outerLoops =
316
337
getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
@@ -515,21 +536,6 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
515
536
return operand;
516
537
}
517
538
518
- /// Succeeds only if `loop` is an scf::ForOp whose first nested loop-like op is the second-to-last op of its body; fails otherwise.
519
- static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
520
- if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
521
- for (auto &&[index, op] :
522
- llvm::enumerate (forOp.getBody ()->getOperations ())) {
523
- // Only the first loop-like op is examined: the check succeeds when it is the
524
- // second-to-last op of the body; the operands of the final op are not inspected.
525
- if (isa<LoopLikeOpInterface>(op)) {
526
- return success ((index + 2 ) == forOp.getBody ()->getOperations ().size ());
527
- }
528
- }
529
- }
530
- return failure ();
531
- }
532
-
533
539
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
534
540
/// tensor.insert_slice. This function makes the following assumptions that
535
541
/// tensor.insert_slice has scf.yield as its only user.
0 commit comments