Skip to content

Commit 2ecb1ab

Browse files
authored
[mlir][scf]: Removed LoopParams struct and used Range instead (NFC) (#95501)
1 parent 4368726 commit 2ecb1ab

File tree

2 files changed

+18
-27
lines changed

2 files changed

+18
-27
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,6 @@ LogicalResult loopUnrollByFactor(
120120
scf::ForOp forOp, uint64_t unrollFactor,
121121
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
122122

123-
/// This structure is to pass and return sets of loop parameters without
124-
/// confusing the order.
125-
struct LoopParams {
126-
OpFoldResult lowerBound;
127-
OpFoldResult upperBound;
128-
OpFoldResult step;
129-
};
130-
131123
/// Transform a loop with a strictly positive step
132124
/// for %i = %lb to %ub step %s
133125
/// into a 0-based loop with step 1
@@ -137,9 +129,9 @@ struct LoopParams {
137129
/// expected to be either `loop` or another loop perfectly nested under `loop`.
138130
/// Insert the definition of new bounds immediate before `outer`, which is
139131
/// expected to be either `loop` or its parent in the loop nest.
140-
LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
141-
OpFoldResult lb, OpFoldResult ub,
142-
OpFoldResult step);
132+
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
133+
OpFoldResult lb, OpFoldResult ub,
134+
OpFoldResult step);
143135

144136
/// Get back the original induction variable values after loop normalization.
145137
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
464464
return success();
465465
}
466466

467-
LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
468-
OpFoldResult lb, OpFoldResult ub,
469-
OpFoldResult step) {
467+
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
468+
OpFoldResult lb, OpFoldResult ub,
469+
OpFoldResult step) {
470470
// For non-index types, generate `arith` instructions
471471
// Check if the loop is already known to have a constant zero lower bound or
472472
// a constant one step.
@@ -478,8 +478,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
478478
if (auto stepCst = getConstantIntValue(step))
479479
isStepOne = stepCst.value() == 1;
480480

481-
Type loopParamsType = getType(lb);
482-
assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
481+
Type rangeType = getType(lb);
482+
assert(rangeType == getType(ub) && rangeType == getType(step) &&
483483
"expected matching types");
484484

485485
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
@@ -501,8 +501,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
501501
getValueOrCreateConstantIntOp(rewriter, loc, step));
502502
}
503503

504-
OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
505-
OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
504+
OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
505+
OpFoldResult newStep = rewriter.getOneAttr(rangeType);
506506

507507
return {newLowerBound, newUpperBound, newStep};
508508
}
@@ -626,18 +626,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
626626
Value lb = loop.getLowerBound();
627627
Value ub = loop.getUpperBound();
628628
Value step = loop.getStep();
629-
auto newLoopParams =
629+
auto newLoopRange =
630630
emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
631631

632632
rewriter.modifyOpInPlace(loop, [&]() {
633-
loop.setLowerBound(getValueOrCreateConstantIntOp(
634-
rewriter, loop.getLoc(), newLoopParams.lowerBound));
635-
loop.setUpperBound(getValueOrCreateConstantIntOp(
636-
rewriter, loop.getLoc(), newLoopParams.upperBound));
633+
loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
634+
newLoopRange.offset));
635+
loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
636+
newLoopRange.size));
637637
loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
638-
newLoopParams.step));
638+
newLoopRange.stride));
639639
});
640-
641640
rewriter.setInsertionPointToStart(innermost.getBody());
642641
denormalizeInductionVariable(rewriter, loop.getLoc(),
643642
loop.getInductionVar(), lb, step);
@@ -780,9 +779,9 @@ void mlir::collapseParallelLoops(
780779
Value lb = loops.getLowerBound()[i];
781780
Value ub = loops.getUpperBound()[i];
782781
Value step = loops.getStep()[i];
783-
auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
782+
auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
784783
normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
785-
rewriter, loops.getLoc(), newLoopParams.upperBound));
784+
rewriter, loops.getLoc(), newLoopRange.size));
786785

787786
rewriter.setInsertionPointToStart(loops.getBody());
788787
denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,

0 commit comments

Comments
 (0)