18
18
#include " mlir/Dialect/SCF/IR/SCF.h"
19
19
#include " mlir/IR/BuiltinOps.h"
20
20
#include " mlir/IR/IRMapping.h"
21
+ #include " mlir/IR/OpDefinition.h"
21
22
#include " mlir/IR/PatternMatch.h"
22
23
#include " mlir/Interfaces/SideEffectInterfaces.h"
23
24
#include " mlir/Support/MathExtras.h"
29
30
30
31
using namespace mlir ;
31
32
32
- namespace {
33
- // This structure is to pass and return sets of loop parameters without
34
- // confusing the order.
35
- struct LoopParams {
36
- Value lowerBound;
37
- Value upperBound;
38
- Value step;
39
- };
40
- } // namespace
41
-
42
33
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields (
43
34
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
44
35
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,17 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
473
464
return success ();
474
465
}
475
466
476
- // / Transform a loop with a strictly positive step
477
- // / for %i = %lb to %ub step %s
478
- // / into a 0-based loop with step 1
479
- // / for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480
- // / %i = %ii * %s + %lb
481
- // / Insert the induction variable remapping in the body of `inner`, which is
482
- // / expected to be either `loop` or another loop perfectly nested under `loop`.
483
- // / Insert the definition of new bounds immediate before `outer`, which is
484
- // / expected to be either `loop` or its parent in the loop nest.
485
- static LoopParams emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
486
- Value lb, Value ub, Value step) {
467
+ LoopParams mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
468
+ OpFoldResult lb, OpFoldResult ub,
469
+ OpFoldResult step) {
487
470
// For non-index types, generate `arith` instructions
488
471
// Check if the loop is already known to have a constant zero lower bound or
489
472
// a constant one step.
@@ -495,45 +478,58 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
495
478
if (auto stepCst = getConstantIntValue (step))
496
479
isStepOne = stepCst.value () == 1 ;
497
480
481
+ Type loopParamsType = getIntType (lb);
482
+ assert (loopParamsType == getIntType (ub) &&
483
+ loopParamsType == getIntType (step) && " expected matching types" );
484
+
498
485
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
499
486
// assuming the step is strictly positive. Update the bounds and the step
500
487
// of the loop to go from 0 to the number of iterations, if necessary.
501
488
if (isZeroBased && isStepOne)
502
489
return {lb, ub, step};
503
490
504
- Value diff = isZeroBased ? ub : rewriter.create <arith::SubIOp>(loc, ub, lb);
505
- Value newUpperBound =
506
- isStepOne ? diff : rewriter.create <arith::CeilDivSIOp>(loc, diff, step);
507
-
508
- Value newLowerBound = isZeroBased
509
- ? lb
510
- : rewriter.create <arith::ConstantOp>(
511
- loc, rewriter.getZeroAttr (lb.getType ()));
512
- Value newStep = isStepOne
513
- ? step
514
- : rewriter.create <arith::ConstantOp>(
515
- loc, rewriter.getIntegerAttr (step.getType (), 1 ));
491
+ OpFoldResult diff = isZeroBased ? ub
492
+ : rewriter.createOrFold <arith::SubIOp>(
493
+ loc,
494
+ getValueOrCreateConstantIntOp (
495
+ rewriter, loc, loopParamsType, ub),
496
+ getValueOrCreateConstantIntOp (
497
+ rewriter, loc, loopParamsType, lb));
498
+ OpFoldResult newUpperBound =
499
+ isStepOne ? diff
500
+ : rewriter.createOrFold <arith::CeilDivSIOp>(
501
+ loc,
502
+ getValueOrCreateConstantIntOp (rewriter, loc,
503
+ loopParamsType, diff),
504
+ getValueOrCreateConstantIntOp (rewriter, loc,
505
+ loopParamsType, step));
506
+
507
+ OpFoldResult newLowerBound = rewriter.getZeroAttr (loopParamsType);
508
+ OpFoldResult newStep = rewriter.getOneAttr (loopParamsType);
516
509
517
510
return {newLowerBound, newUpperBound, newStep};
518
511
}
519
512
520
- // / Get back the original induction variable values after loop normalization
521
- static void denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
522
- Value normalizedIv, Value origLb,
523
- Value origStep) {
513
+ void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
514
+ Value normalizedIv, OpFoldResult origLb,
515
+ OpFoldResult origStep) {
524
516
Value denormalizedIv;
525
517
SmallPtrSet<Operation *, 2 > preserve;
526
518
bool isStepOne = isConstantIntValue (origStep, 1 );
527
519
bool isZeroBased = isConstantIntValue (origLb, 0 );
528
520
529
521
Value scaled = normalizedIv;
530
522
if (!isStepOne) {
531
- scaled = rewriter.create <arith::MulIOp>(loc, normalizedIv, origStep);
523
+ Value origStepValue = getValueOrCreateConstantIntOp (
524
+ rewriter, loc, getIntType (origStep), origStep);
525
+ scaled = rewriter.create <arith::MulIOp>(loc, normalizedIv, origStepValue);
532
526
preserve.insert (scaled.getDefiningOp ());
533
527
}
534
528
denormalizedIv = scaled;
535
529
if (!isZeroBased) {
536
- denormalizedIv = rewriter.create <arith::AddIOp>(loc, scaled, origLb);
530
+ Value origLbValue = getValueOrCreateConstantIntOp (
531
+ rewriter, loc, getIntType (origLb), origLb);
532
+ denormalizedIv = rewriter.create <arith::AddIOp>(loc, scaled, origLbValue);
537
533
preserve.insert (denormalizedIv.getDefiningOp ());
538
534
}
539
535
@@ -638,9 +634,13 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
638
634
emitNormalizedLoopBounds (rewriter, loop.getLoc (), lb, ub, step);
639
635
640
636
rewriter.modifyOpInPlace (loop, [&]() {
641
- loop.setLowerBound (newLoopParams.lowerBound );
642
- loop.setUpperBound (newLoopParams.upperBound );
643
- loop.setStep (newLoopParams.step );
637
+ Type loopParamsType = lb.getType ();
638
+ loop.setLowerBound (getValueOrCreateConstantIntOp (
639
+ rewriter, loop.getLoc (), loopParamsType, newLoopParams.lowerBound ));
640
+ loop.setUpperBound (getValueOrCreateConstantIntOp (
641
+ rewriter, loop.getLoc (), loopParamsType, newLoopParams.upperBound ));
642
+ loop.setStep (getValueOrCreateConstantIntOp (
643
+ rewriter, loop.getLoc (), loopParamsType, newLoopParams.step ));
644
644
});
645
645
646
646
rewriter.setInsertionPointToStart (innermost.getBody ());
@@ -778,18 +778,16 @@ void mlir::collapseParallelLoops(
778
778
llvm::sort (dims);
779
779
780
780
// Normalize ParallelOp's iteration pattern.
781
- SmallVector<Value, 3 > normalizedLowerBounds, normalizedSteps,
782
- normalizedUpperBounds;
781
+ SmallVector<Value, 3 > normalizedUpperBounds;
783
782
for (unsigned i = 0 , e = loops.getNumLoops (); i < e; ++i) {
784
783
OpBuilder::InsertionGuard g2 (rewriter);
785
784
rewriter.setInsertionPoint (loops);
786
785
Value lb = loops.getLowerBound ()[i];
787
786
Value ub = loops.getUpperBound ()[i];
788
787
Value step = loops.getStep ()[i];
789
788
auto newLoopParams = emitNormalizedLoopBounds (rewriter, loc, lb, ub, step);
790
- normalizedLowerBounds.push_back (newLoopParams.lowerBound );
791
- normalizedUpperBounds.push_back (newLoopParams.upperBound );
792
- normalizedSteps.push_back (newLoopParams.step );
789
+ normalizedUpperBounds.push_back (getValueOrCreateConstantIntOp (
790
+ rewriter, loops.getLoc (), ub.getType (), newLoopParams.upperBound ));
793
791
794
792
rewriter.setInsertionPointToStart (loops.getBody ());
795
793
denormalizeInductionVariable (rewriter, loc, loops.getInductionVars ()[i], lb,
0 commit comments