Skip to content

[mlir][scf]: Removed LoopParams struct and used Range instead (NFC) #95501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,6 @@ LogicalResult loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);

/// This structure is to pass and return sets of loop parameters without
/// confusing the order.
struct LoopParams {
OpFoldResult lowerBound;
OpFoldResult upperBound;
OpFoldResult step;
};

/// Transform a loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
Expand All @@ -137,9 +129,9 @@ struct LoopParams {
/// expected to be either `loop` or another loop perfectly nested under `loop`.
/// Insert the definition of new bounds immediate before `outer`, which is
/// expected to be either `loop` or its parent in the loop nest.
LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);

/// Get back the original induction variable values after loop normalization.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Expand Down
31 changes: 15 additions & 16 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}

LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
// For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
Expand All @@ -478,8 +478,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
if (auto stepCst = getConstantIntValue(step))
isStepOne = stepCst.value() == 1;

Type loopParamsType = getType(lb);
assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
Type rangeType = getType(lb);
assert(rangeType == getType(ub) && rangeType == getType(step) &&
"expected matching types");

// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
Expand All @@ -501,8 +501,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
getValueOrCreateConstantIntOp(rewriter, loc, step));
}

OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
OpFoldResult newStep = rewriter.getOneAttr(rangeType);

return {newLowerBound, newUpperBound, newStep};
}
Expand Down Expand Up @@ -626,18 +626,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
Value lb = loop.getLowerBound();
Value ub = loop.getUpperBound();
Value step = loop.getStep();
auto newLoopParams =
auto newLoopRange =
emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

rewriter.modifyOpInPlace(loop, [&]() {
loop.setLowerBound(getValueOrCreateConstantIntOp(
rewriter, loop.getLoc(), newLoopParams.lowerBound));
loop.setUpperBound(getValueOrCreateConstantIntOp(
rewriter, loop.getLoc(), newLoopParams.upperBound));
loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
newLoopRange.offset));
loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
newLoopRange.size));
loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
newLoopParams.step));
newLoopRange.stride));
});

rewriter.setInsertionPointToStart(innermost.getBody());
denormalizeInductionVariable(rewriter, loop.getLoc(),
loop.getInductionVar(), lb, step);
Expand Down Expand Up @@ -780,9 +779,9 @@ void mlir::collapseParallelLoops(
Value lb = loops.getLowerBound()[i];
Value ub = loops.getUpperBound()[i];
Value step = loops.getStep()[i];
auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
rewriter, loops.getLoc(), newLoopParams.upperBound));
rewriter, loops.getLoc(), newLoopRange.size));

rewriter.setInsertionPointToStart(loops.getBody());
denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
Expand Down
Loading