18
18
#include " mlir/Dialect/SCF/IR/SCF.h"
19
19
#include " mlir/IR/BuiltinOps.h"
20
20
#include " mlir/IR/IRMapping.h"
21
+ #include " mlir/IR/OpDefinition.h"
21
22
#include " mlir/IR/PatternMatch.h"
22
23
#include " mlir/Interfaces/SideEffectInterfaces.h"
23
24
#include " mlir/Transforms/RegionUtils.h"
29
30
30
31
using namespace mlir ;
31
32
32
- namespace {
33
- // This structure is to pass and return sets of loop parameters without
34
- // confusing the order.
35
- struct LoopParams {
36
- Value lowerBound;
37
- Value upperBound;
38
- Value step;
39
- };
40
- } // namespace
41
-
42
33
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields (
43
34
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
44
35
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,17 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
473
464
return success ();
474
465
}
475
466
476
- // / Transform a loop with a strictly positive step
477
- // / for %i = %lb to %ub step %s
478
- // / into a 0-based loop with step 1
479
- // / for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480
- // / %i = %ii * %s + %lb
481
- // / Insert the induction variable remapping in the body of `inner`, which is
482
- // / expected to be either `loop` or another loop perfectly nested under `loop`.
483
- // / Insert the definition of new bounds immediately before `outer`, which is
484
- // / expected to be either `loop` or its parent in the loop nest.
485
- static LoopParams emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
486
- Value lb, Value ub, Value step) {
467
+ LoopParams mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
468
+ OpFoldResult lb, OpFoldResult ub,
469
+ OpFoldResult step) {
487
470
// For non-index types, generate `arith` instructions
488
471
// Check if the loop is already known to have a constant zero lower bound or
489
472
// a constant one step.
@@ -495,45 +478,54 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
495
478
if (auto stepCst = getConstantIntValue (step))
496
479
isStepOne = stepCst.value () == 1 ;
497
480
481
+ Type loopParamsType = getType (lb);
482
+ assert (loopParamsType == getType (ub) && loopParamsType == getType (step) &&
483
+ " expected matching types" );
484
+
498
485
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
499
486
// assuming the step is strictly positive. Update the bounds and the step
500
487
// of the loop to go from 0 to the number of iterations, if necessary.
501
488
if (isZeroBased && isStepOne)
502
489
return {lb, ub, step};
503
490
504
- Value diff = isZeroBased ? ub : rewriter.create <arith::SubIOp>(loc, ub, lb);
505
- Value newUpperBound =
506
- isStepOne ? diff : rewriter.create <arith::CeilDivSIOp>(loc, diff, step);
491
+ OpFoldResult diff = ub;
492
+ if (!isZeroBased) {
493
+ diff = rewriter.createOrFold <arith::SubIOp>(
494
+ loc, getValueOrCreateConstantIntOp (rewriter, loc, ub),
495
+ getValueOrCreateConstantIntOp (rewriter, loc, lb));
496
+ }
497
+ OpFoldResult newUpperBound = diff;
498
+ if (!isStepOne) {
499
+ newUpperBound = rewriter.createOrFold <arith::CeilDivSIOp>(
500
+ loc, getValueOrCreateConstantIntOp (rewriter, loc, diff),
501
+ getValueOrCreateConstantIntOp (rewriter, loc, step));
502
+ }
507
503
508
- Value newLowerBound = isZeroBased
509
- ? lb
510
- : rewriter.create <arith::ConstantOp>(
511
- loc, rewriter.getZeroAttr (lb.getType ()));
512
- Value newStep = isStepOne
513
- ? step
514
- : rewriter.create <arith::ConstantOp>(
515
- loc, rewriter.getIntegerAttr (step.getType (), 1 ));
504
+ OpFoldResult newLowerBound = rewriter.getZeroAttr (loopParamsType);
505
+ OpFoldResult newStep = rewriter.getOneAttr (loopParamsType);
516
506
517
507
return {newLowerBound, newUpperBound, newStep};
518
508
}
519
509
520
- // / Get back the original induction variable values after loop normalization
521
- static void denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
522
- Value normalizedIv, Value origLb,
523
- Value origStep) {
510
+ void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
511
+ Value normalizedIv, OpFoldResult origLb,
512
+ OpFoldResult origStep) {
524
513
Value denormalizedIv;
525
514
SmallPtrSet<Operation *, 2 > preserve;
526
515
bool isStepOne = isConstantIntValue (origStep, 1 );
527
516
bool isZeroBased = isConstantIntValue (origLb, 0 );
528
517
529
518
Value scaled = normalizedIv;
530
519
if (!isStepOne) {
531
- scaled = rewriter.create <arith::MulIOp>(loc, normalizedIv, origStep);
520
+ Value origStepValue =
521
+ getValueOrCreateConstantIntOp (rewriter, loc, origStep);
522
+ scaled = rewriter.create <arith::MulIOp>(loc, normalizedIv, origStepValue);
532
523
preserve.insert (scaled.getDefiningOp ());
533
524
}
534
525
denormalizedIv = scaled;
535
526
if (!isZeroBased) {
536
- denormalizedIv = rewriter.create <arith::AddIOp>(loc, scaled, origLb);
527
+ Value origLbValue = getValueOrCreateConstantIntOp (rewriter, loc, origLb);
528
+ denormalizedIv = rewriter.create <arith::AddIOp>(loc, scaled, origLbValue);
537
529
preserve.insert (denormalizedIv.getDefiningOp ());
538
530
}
539
531
@@ -638,9 +630,12 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
638
630
emitNormalizedLoopBounds (rewriter, loop.getLoc (), lb, ub, step);
639
631
640
632
rewriter.modifyOpInPlace (loop, [&]() {
641
- loop.setLowerBound (newLoopParams.lowerBound );
642
- loop.setUpperBound (newLoopParams.upperBound );
643
- loop.setStep (newLoopParams.step );
633
+ loop.setLowerBound (getValueOrCreateConstantIntOp (
634
+ rewriter, loop.getLoc (), newLoopParams.lowerBound ));
635
+ loop.setUpperBound (getValueOrCreateConstantIntOp (
636
+ rewriter, loop.getLoc (), newLoopParams.upperBound ));
637
+ loop.setStep (getValueOrCreateConstantIntOp (rewriter, loop.getLoc (),
638
+ newLoopParams.step ));
644
639
});
645
640
646
641
rewriter.setInsertionPointToStart (innermost.getBody ());
@@ -778,18 +773,16 @@ void mlir::collapseParallelLoops(
778
773
llvm::sort (dims);
779
774
780
775
// Normalize ParallelOp's iteration pattern.
781
- SmallVector<Value, 3 > normalizedLowerBounds, normalizedSteps,
782
- normalizedUpperBounds;
776
+ SmallVector<Value, 3 > normalizedUpperBounds;
783
777
for (unsigned i = 0 , e = loops.getNumLoops (); i < e; ++i) {
784
778
OpBuilder::InsertionGuard g2 (rewriter);
785
779
rewriter.setInsertionPoint (loops);
786
780
Value lb = loops.getLowerBound ()[i];
787
781
Value ub = loops.getUpperBound ()[i];
788
782
Value step = loops.getStep ()[i];
789
783
auto newLoopParams = emitNormalizedLoopBounds (rewriter, loc, lb, ub, step);
790
- normalizedLowerBounds.push_back (newLoopParams.lowerBound );
791
- normalizedUpperBounds.push_back (newLoopParams.upperBound );
792
- normalizedSteps.push_back (newLoopParams.step );
784
+ normalizedUpperBounds.push_back (getValueOrCreateConstantIntOp (
785
+ rewriter, loops.getLoc (), newLoopParams.upperBound ));
793
786
794
787
rewriter.setInsertionPointToStart (loops.getBody ());
795
788
denormalizeInductionVariable (rewriter, loc, loops.getInductionVars ()[i], lb,
0 commit comments