18
18
#include " mlir/Dialect/SCF/IR/SCF.h"
19
19
#include " mlir/IR/BuiltinOps.h"
20
20
#include " mlir/IR/IRMapping.h"
21
+ #include " mlir/IR/OpDefinition.h"
21
22
#include " mlir/IR/PatternMatch.h"
22
23
#include " mlir/Interfaces/SideEffectInterfaces.h"
23
24
#include " mlir/Support/MathExtras.h"
29
30
30
31
using namespace mlir ;
31
32
32
- namespace {
33
- // This structure is to pass and return sets of loop parameters without
34
- // confusing the order.
35
- struct LoopParams {
36
- Value lowerBound;
37
- Value upperBound;
38
- Value step;
39
- };
40
- } // namespace
41
-
42
33
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields (
43
34
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
44
35
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,17 +464,8 @@ LogicalResult mlir::loopUnrollByFactor(
473
464
return success ();
474
465
}
475
466
476
- // / Transform a loop with a strictly positive step
477
- // / for %i = %lb to %ub step %s
478
- // / into a 0-based loop with step 1
479
- // / for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480
- // / %i = %ii * %s + %lb
481
- // / Insert the induction variable remapping in the body of `inner`, which is
482
- // / expected to be either `loop` or another loop perfectly nested under `loop`.
483
- // / Insert the definition of new bounds immediate before `outer`, which is
484
- // / expected to be either `loop` or its parent in the loop nest.
485
- static LoopParams emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
486
- Value lb, Value ub, Value step) {
467
+ LoopParams mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
468
+ Value lb, Value ub, Value step) {
487
469
// For non-index types, generate `arith` instructions
488
470
// Check if the loop is already known to have a constant zero lower bound or
489
471
// a constant one step.
@@ -501,26 +483,27 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
501
483
if (isZeroBased && isStepOne)
502
484
return {lb, ub, step};
503
485
504
- Value diff = isZeroBased ? ub : rewriter.create <arith::SubIOp>(loc, ub, lb);
505
- Value newUpperBound =
506
- isStepOne ? diff : rewriter.create <arith::CeilDivSIOp>(loc, diff, step);
507
-
508
- Value newLowerBound = isZeroBased
509
- ? lb
510
- : rewriter.create <arith::ConstantOp>(
511
- loc, rewriter.getZeroAttr (lb.getType ()));
512
- Value newStep = isStepOne
513
- ? step
514
- : rewriter.create <arith::ConstantOp>(
515
- loc, rewriter.getIntegerAttr (step.getType (), 1 ));
486
+ auto diff =
487
+ isZeroBased ? ub : rewriter.createOrFold <arith::SubIOp>(loc, ub, lb);
488
+ auto newUpperBound =
489
+ isStepOne ? diff
490
+ : rewriter.createOrFold <arith::CeilDivSIOp>(loc, diff, step);
491
+
492
+ auto newLowerBound = isZeroBased
493
+ ? lb
494
+ : rewriter.createOrFold <arith::ConstantOp>(
495
+ loc, rewriter.getZeroAttr (lb.getType ()));
496
+ auto newStep = isStepOne
497
+ ? step
498
+ : rewriter.createOrFold <arith::ConstantOp>(
499
+ loc, rewriter.getIntegerAttr (step.getType (), 1 ));
516
500
517
501
return {newLowerBound, newUpperBound, newStep};
518
502
}
519
503
520
- // / Get back the original induction variable values after loop normalization
521
- static void denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
522
- Value normalizedIv, Value origLb,
523
- Value origStep) {
504
+ void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
505
+ Value normalizedIv, Value origLb,
506
+ Value origStep) {
524
507
Value denormalizedIv;
525
508
SmallPtrSet<Operation *, 2 > preserve;
526
509
bool isStepOne = isConstantIntValue (origStep, 1 );
@@ -638,9 +621,9 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
638
621
emitNormalizedLoopBounds (rewriter, loop.getLoc (), lb, ub, step);
639
622
640
623
rewriter.modifyOpInPlace (loop, [&]() {
641
- loop.setLowerBound (newLoopParams.lowerBound );
642
- loop.setUpperBound (newLoopParams.upperBound );
643
- loop.setStep (newLoopParams.step );
624
+ loop.setLowerBound (cast<Value>( newLoopParams.lowerBound ) );
625
+ loop.setUpperBound (cast<Value>( newLoopParams.upperBound ) );
626
+ loop.setStep (cast<Value>( newLoopParams.step ) );
644
627
});
645
628
646
629
rewriter.setInsertionPointToStart (innermost.getBody ());
@@ -778,18 +761,15 @@ void mlir::collapseParallelLoops(
778
761
llvm::sort (dims);
779
762
780
763
// Normalize ParallelOp's iteration pattern.
781
- SmallVector<Value, 3 > normalizedLowerBounds, normalizedSteps,
782
- normalizedUpperBounds;
764
+ SmallVector<Value, 3 > normalizedUpperBounds;
783
765
for (unsigned i = 0 , e = loops.getNumLoops (); i < e; ++i) {
784
766
OpBuilder::InsertionGuard g2 (rewriter);
785
767
rewriter.setInsertionPoint (loops);
786
768
Value lb = loops.getLowerBound ()[i];
787
769
Value ub = loops.getUpperBound ()[i];
788
770
Value step = loops.getStep ()[i];
789
771
auto newLoopParams = emitNormalizedLoopBounds (rewriter, loc, lb, ub, step);
790
- normalizedLowerBounds.push_back (newLoopParams.lowerBound );
791
- normalizedUpperBounds.push_back (newLoopParams.upperBound );
792
- normalizedSteps.push_back (newLoopParams.step );
772
+ normalizedUpperBounds.push_back (cast<Value>(newLoopParams.upperBound ));
793
773
794
774
rewriter.setInsertionPointToStart (loops.getBody ());
795
775
denormalizeInductionVariable (rewriter, loc, loops.getInductionVars ()[i], lb,
0 commit comments