@@ -486,6 +486,135 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
486
486
return success ();
487
487
}
488
488
489
+ LogicalResult mlir::loopUnrollJamUpToFactor (AffineForOp forOp,
490
+ uint64_t unrollJamFactor) {
491
+ Optional<uint64_t > mayBeConstantTripCount = getConstantTripCount (forOp);
492
+
493
+ if (mayBeConstantTripCount.hasValue () &&
494
+ mayBeConstantTripCount.getValue () < unrollJamFactor)
495
+ return loopUnrollJamByFactor (forOp, mayBeConstantTripCount.getValue ());
496
+ return loopUnrollJamByFactor (forOp, unrollJamFactor);
497
+ }
498
+
499
+ // / Unrolls and jams this loop by the specified factor.
500
+ LogicalResult mlir::loopUnrollJamByFactor (AffineForOp forOp,
501
+ uint64_t unrollJamFactor) {
502
+ // Gathers all maximal sub-blocks of operations that do not themselves
503
+ // include a for op (a operation could have a descendant for op though
504
+ // in its tree). Ignore the block terminators.
505
+ struct JamBlockGatherer {
506
+ // Store iterators to the first and last op of each sub-block found.
507
+ std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
508
+
509
+ // This is a linear time walk.
510
+ void walk (Operation *op) {
511
+ for (auto ®ion : op->getRegions ())
512
+ for (auto &block : region)
513
+ walk (block);
514
+ }
515
+ void walk (Block &block) {
516
+ for (auto it = block.begin (), e = std::prev (block.end ()); it != e;) {
517
+ auto subBlockStart = it;
518
+ while (it != e && !isa<AffineForOp>(&*it))
519
+ ++it;
520
+ if (it != subBlockStart)
521
+ subBlocks.push_back ({subBlockStart, std::prev (it)});
522
+ // Process all for insts that appear next.
523
+ while (it != e && isa<AffineForOp>(&*it))
524
+ walk (&*it++);
525
+ }
526
+ }
527
+ };
528
+
529
+ assert (unrollJamFactor >= 1 && " unroll jam factor should be >= 1" );
530
+
531
+ if (unrollJamFactor == 1 )
532
+ return promoteIfSingleIteration (forOp);
533
+
534
+ if (forOp.getBody ()->empty () ||
535
+ forOp.getBody ()->begin () == std::prev (forOp.getBody ()->end ()))
536
+ return failure ();
537
+
538
+ // Loops where both lower and upper bounds are multi-result maps won't be
539
+ // unrolled (since the trip can't be expressed as an affine function in
540
+ // general).
541
+ // TODO(mlir-team): this may not be common, but we could support the case
542
+ // where the lower bound is a multi-result map and the ub is a single result
543
+ // one.
544
+ if (forOp.getLowerBoundMap ().getNumResults () != 1 )
545
+ return failure ();
546
+
547
+ Optional<uint64_t > mayBeConstantTripCount = getConstantTripCount (forOp);
548
+ // If the trip count is lower than the unroll jam factor, no unroll jam.
549
+ if (mayBeConstantTripCount.hasValue () &&
550
+ mayBeConstantTripCount.getValue () < unrollJamFactor)
551
+ return failure ();
552
+
553
+ auto *forInst = forOp.getOperation ();
554
+
555
+ // Gather all sub-blocks to jam upon the loop being unrolled.
556
+ JamBlockGatherer jbg;
557
+ jbg.walk (forInst);
558
+ auto &subBlocks = jbg.subBlocks ;
559
+
560
+ // Generate the cleanup loop if trip count isn't a multiple of
561
+ // unrollJamFactor.
562
+ if (getLargestDivisorOfTripCount (forOp) % unrollJamFactor != 0 ) {
563
+ // Insert the cleanup loop right after 'forOp'.
564
+ OpBuilder builder (forInst->getBlock (), std::next (Block::iterator (forInst)));
565
+ auto cleanupAffineForOp = cast<AffineForOp>(builder.clone (*forInst));
566
+ // Adjust the lower bound of the cleanup loop; its upper bound is the same
567
+ // as the original loop's upper bound.
568
+ AffineMap cleanupMap;
569
+ SmallVector<Value, 4 > cleanupOperands;
570
+ getCleanupLoopLowerBound (forOp, unrollJamFactor, &cleanupMap,
571
+ &cleanupOperands, builder);
572
+ cleanupAffineForOp.setLowerBound (cleanupOperands, cleanupMap);
573
+
574
+ // Promote the cleanup loop if it has turned into a single iteration loop.
575
+ promoteIfSingleIteration (cleanupAffineForOp);
576
+
577
+ // Adjust the upper bound of the original loop - it will be the same as the
578
+ // cleanup loop's lower bound. Its lower bound remains unchanged.
579
+ forOp.setUpperBound (cleanupOperands, cleanupMap);
580
+ }
581
+
582
+ // Scale the step of loop being unroll-jammed by the unroll-jam factor.
583
+ int64_t step = forOp.getStep ();
584
+ forOp.setStep (step * unrollJamFactor);
585
+
586
+ auto forOpIV = forOp.getInductionVar ();
587
+ // Unroll and jam (appends unrollJamFactor - 1 additional copies).
588
+ for (unsigned i = unrollJamFactor - 1 ; i >= 1 ; --i) {
589
+ // Operand map persists across all sub-blocks.
590
+ BlockAndValueMapping operandMapping;
591
+ for (auto &subBlock : subBlocks) {
592
+ // Builder to insert unroll-jammed bodies. Insert right at the end of
593
+ // sub-block.
594
+ OpBuilder builder (subBlock.first ->getBlock (), std::next (subBlock.second ));
595
+
596
+ // If the induction variable is used, create a remapping to the value for
597
+ // this unrolled instance.
598
+ if (!forOpIV.use_empty ()) {
599
+ // iv' = iv + i, i = 1 to unrollJamFactor-1.
600
+ auto d0 = builder.getAffineDimExpr (0 );
601
+ auto bumpMap = AffineMap::get (1 , 0 , {d0 + i * step});
602
+ auto ivUnroll =
603
+ builder.create <AffineApplyOp>(forInst->getLoc (), bumpMap, forOpIV);
604
+ operandMapping.map (forOpIV, ivUnroll);
605
+ }
606
+ // Clone the sub-block being unroll-jammed.
607
+ for (auto it = subBlock.first ; it != std::next (subBlock.second ); ++it) {
608
+ builder.clone (*it, operandMapping);
609
+ }
610
+ }
611
+ }
612
+
613
+ // Promote the loop body up if this has turned into a single iteration loop.
614
+ promoteIfSingleIteration (forOp);
615
+ return success ();
616
+ }
617
+
489
618
// / Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
490
619
// / nested within 'forOpA' as the only non-terminator operation in its block.
491
620
void mlir::interchangeLoops (AffineForOp forOpA, AffineForOp forOpB) {
0 commit comments