@@ -370,6 +370,120 @@ ValueVector mlir::scf::buildLoopNest(
370
370
});
371
371
}
372
372
373
+ namespace {
374
+ // Fold away ForOp iter arguments that are also yielded by the op.
375
+ // These arguments must be defined outside of the ForOp region and can just be
376
+ // forwarded after simplifying the op inits, yields and returns.
377
+ //
378
+ // The implementation uses `mergeBlockBefore` to steal the content of the
379
+ // original ForOp and avoid cloning.
380
+ struct ForOpIterArgsFolder : public OpRewritePattern <scf::ForOp> {
381
+ using OpRewritePattern<scf::ForOp>::OpRewritePattern;
382
+
383
+ LogicalResult matchAndRewrite (scf::ForOp forOp,
384
+ PatternRewriter &rewriter) const final {
385
+ bool canonicalize = false ;
386
+ Block &block = forOp.region ().front ();
387
+ auto yieldOp = cast<scf::YieldOp>(block.getTerminator ());
388
+
389
+ // An internal flat vector of block transfer
390
+ // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
391
+ // transformed block argument mappings. This plays the role of a
392
+ // BlockAndValueMapping for the particular use case of calling into
393
+ // `mergeBlockBefore`.
394
+ SmallVector<bool , 4 > keepMask;
395
+ keepMask.reserve (yieldOp.getNumOperands ());
396
+ SmallVector<Value, 4 > newBlockTransferArgs, newIterArgs, newYieldValues,
397
+ newResultValues;
398
+ newBlockTransferArgs.reserve (1 + forOp.getNumIterOperands ());
399
+ newBlockTransferArgs.push_back (Value ()); // iv placeholder with null value
400
+ newIterArgs.reserve (forOp.getNumIterOperands ());
401
+ newYieldValues.reserve (yieldOp.getNumOperands ());
402
+ newResultValues.reserve (forOp.getNumResults ());
403
+ for (auto it : llvm::zip (forOp.getIterOperands (), // iter from outside
404
+ forOp.getRegionIterArgs (), // iter inside region
405
+ yieldOp.getOperands () // iter yield
406
+ )) {
407
+ // Forwarded is `true` when the region `iter` argument is yielded.
408
+ bool forwarded = (std::get<1 >(it) == std::get<2 >(it));
409
+ keepMask.push_back (!forwarded);
410
+ canonicalize |= forwarded;
411
+ if (forwarded) {
412
+ newBlockTransferArgs.push_back (std::get<0 >(it));
413
+ newResultValues.push_back (std::get<0 >(it));
414
+ continue ;
415
+ }
416
+ newIterArgs.push_back (std::get<0 >(it));
417
+ newYieldValues.push_back (std::get<2 >(it));
418
+ newBlockTransferArgs.push_back (Value ()); // placeholder with null value
419
+ newResultValues.push_back (Value ()); // placeholder with null value
420
+ }
421
+
422
+ if (!canonicalize)
423
+ return failure ();
424
+
425
+ scf::ForOp newForOp = rewriter.create <scf::ForOp>(
426
+ forOp.getLoc (), forOp.lowerBound (), forOp.upperBound (), forOp.step (),
427
+ newIterArgs);
428
+ Block &newBlock = newForOp.region ().front ();
429
+
430
+ // Replace the null placeholders with newly constructed values.
431
+ newBlockTransferArgs[0 ] = newBlock.getArgument (0 ); // iv
432
+ for (unsigned idx = 0 , collapsedIdx = 0 , e = newResultValues.size ();
433
+ idx != e; ++idx) {
434
+ Value &blockTransferArg = newBlockTransferArgs[1 + idx];
435
+ Value &newResultVal = newResultValues[idx];
436
+ assert ((blockTransferArg && newResultVal) ||
437
+ (!blockTransferArg && !newResultVal));
438
+ if (!blockTransferArg) {
439
+ blockTransferArg = newForOp.getRegionIterArgs ()[collapsedIdx];
440
+ newResultVal = newForOp.getResult (collapsedIdx++);
441
+ }
442
+ }
443
+
444
+ Block &oldBlock = forOp.region ().front ();
445
+ assert (oldBlock.getNumArguments () == newBlockTransferArgs.size () &&
446
+ " unexpected argument size mismatch" );
447
+
448
+ // No results case: the scf::ForOp builder already created a zero
449
+ // reult terminator. Merge before this terminator and just get rid of the
450
+ // original terminator that has been merged in.
451
+ if (newIterArgs.empty ()) {
452
+ auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator ());
453
+ rewriter.mergeBlockBefore (&oldBlock, newYieldOp, newBlockTransferArgs);
454
+ rewriter.eraseOp (newBlock.getTerminator ()->getPrevNode ());
455
+ rewriter.replaceOp (forOp, newResultValues);
456
+ return success ();
457
+ }
458
+
459
+ // No terminator case: merge and rewrite the merged terminator.
460
+ auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
461
+ OpBuilder::InsertionGuard g (rewriter);
462
+ rewriter.setInsertionPoint (mergedTerminator);
463
+ SmallVector<Value, 4 > filteredOperands;
464
+ filteredOperands.reserve (newResultValues.size ());
465
+ for (unsigned idx = 0 , e = keepMask.size (); idx < e; ++idx)
466
+ if (keepMask[idx])
467
+ filteredOperands.push_back (mergedTerminator.getOperand (idx));
468
+ rewriter.create <scf::YieldOp>(mergedTerminator.getLoc (),
469
+ filteredOperands);
470
+ };
471
+
472
+ rewriter.mergeBlocks (&oldBlock, &newBlock, newBlockTransferArgs);
473
+ auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator ());
474
+ cloneFilteredTerminator (mergedYieldOp);
475
+ rewriter.eraseOp (mergedYieldOp);
476
+ rewriter.replaceOp (forOp, newResultValues);
477
+ return success ();
478
+ }
479
+ };
480
+ } // namespace
481
+
482
+ void ForOp::getCanonicalizationPatterns (OwningRewritePatternList &results,
483
+ MLIRContext *context) {
484
+ results.insert <ForOpIterArgsFolder>(context);
485
+ }
486
+
373
487
// ===----------------------------------------------------------------------===//
374
488
// IfOp
375
489
// ===----------------------------------------------------------------------===//
0 commit comments