-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][scf]: Removed LoopParams struct and used Range instead (NFC) #95501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Aviad Cohen (AviadCo) ChangesFull diff: https://github.com/llvm/llvm-project/pull/95501.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index f719c00213987..da3fe3ceb86be 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,14 +120,6 @@ LogicalResult loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
-/// This structure is to pass and return sets of loop parameters without
-/// confusing the order.
-struct LoopParams {
- OpFoldResult lowerBound;
- OpFoldResult upperBound;
- OpFoldResult step;
-};
-
/// Transform a loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
@@ -137,9 +129,9 @@ struct LoopParams {
/// expected to be either `loop` or another loop perfectly nested under `loop`.
/// Insert the definition of new bounds immediate before `outer`, which is
/// expected to be either `loop` or its parent in the loop nest.
-LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
- OpFoldResult lb, OpFoldResult ub,
- OpFoldResult step);
+Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+ OpFoldResult lb, OpFoldResult ub,
+ OpFoldResult step);
/// Get back the original induction variable values after loop normalization.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a031e53fe0ffb..ff5e3a002263d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -464,9 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
-LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
- OpFoldResult lb, OpFoldResult ub,
- OpFoldResult step) {
+Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+ OpFoldResult lb, OpFoldResult ub,
+ OpFoldResult step) {
// For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
@@ -478,8 +478,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
if (auto stepCst = getConstantIntValue(step))
isStepOne = stepCst.value() == 1;
- Type loopParamsType = getType(lb);
- assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
+ Type rangeType = getType(lb);
+ assert(rangeType == getType(ub) && rangeType == getType(step) &&
"expected matching types");
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
@@ -501,8 +501,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
getValueOrCreateConstantIntOp(rewriter, loc, step));
}
- OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
- OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
+ OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
+ OpFoldResult newStep = rewriter.getOneAttr(rangeType);
return {newLowerBound, newUpperBound, newStep};
}
@@ -626,18 +626,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
Value lb = loop.getLowerBound();
Value ub = loop.getUpperBound();
Value step = loop.getStep();
- auto newLoopParams =
+ auto newLoopRange =
emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
rewriter.modifyOpInPlace(loop, [&]() {
- loop.setLowerBound(getValueOrCreateConstantIntOp(
- rewriter, loop.getLoc(), newLoopParams.lowerBound));
- loop.setUpperBound(getValueOrCreateConstantIntOp(
- rewriter, loop.getLoc(), newLoopParams.upperBound));
+ loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+ newLoopRange.offset));
+ loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+ newLoopRange.size));
loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
- newLoopParams.step));
+ newLoopRange.stride));
});
-
rewriter.setInsertionPointToStart(innermost.getBody());
denormalizeInductionVariable(rewriter, loop.getLoc(),
loop.getInductionVar(), lb, step);
@@ -780,9 +779,9 @@ void mlir::collapseParallelLoops(
Value lb = loops.getLowerBound()[i];
Value ub = loops.getUpperBound()[i];
Value step = loops.getStep()[i];
- auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+ auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
- rewriter, loops.getLoc(), newLoopParams.upperBound));
+ rewriter, loops.getLoc(), newLoopRange.size));
rewriter.setInsertionPointToStart(loops.getBody());
denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!!
Maybe you want to add a test for this? |
@MaheshRavishankar those functions are covered in multiple tests (used by |
No description provided.