Skip to content

[mlir][scf]: Removed LoopParams struct and used Range instead (NFC) #95501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2024

Conversation

AviadCo
Copy link
Contributor

@AviadCo AviadCo commented Jun 14, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 14, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/95501.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+3-11)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+15-16)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index f719c00213987..da3fe3ceb86be 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,14 +120,6 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
-/// This structure is to pass and return sets of loop parameters without
-/// confusing the order.
-struct LoopParams {
-  OpFoldResult lowerBound;
-  OpFoldResult upperBound;
-  OpFoldResult step;
-};
-
 /// Transform a loop with a strictly positive step
 ///   for %i = %lb to %ub step %s
 /// into a 0-based loop with step 1
@@ -137,9 +129,9 @@ struct LoopParams {
 /// expected to be either `loop` or another loop perfectly nested under `loop`.
 /// Insert the definition of new bounds immediate before `outer`, which is
 /// expected to be either `loop` or its parent in the loop nest.
-LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                    OpFoldResult lb, OpFoldResult ub,
-                                    OpFoldResult step);
+Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                               OpFoldResult lb, OpFoldResult ub,
+                               OpFoldResult step);
 
 /// Get back the original induction variable values after loop normalization.
 void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a031e53fe0ffb..ff5e3a002263d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -464,9 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                          OpFoldResult lb, OpFoldResult ub,
-                                          OpFoldResult step) {
+Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                                     OpFoldResult lb, OpFoldResult ub,
+                                     OpFoldResult step) {
   // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
@@ -478,8 +478,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
   if (auto stepCst = getConstantIntValue(step))
     isStepOne = stepCst.value() == 1;
 
-  Type loopParamsType = getType(lb);
-  assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
+  Type rangeType = getType(lb);
+  assert(rangeType == getType(ub) && rangeType == getType(step) &&
          "expected matching types");
 
   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
@@ -501,8 +501,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
         getValueOrCreateConstantIntOp(rewriter, loc, step));
   }
 
-  OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
-  OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
+  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
+  OpFoldResult newStep = rewriter.getOneAttr(rangeType);
 
   return {newLowerBound, newUpperBound, newStep};
 }
@@ -626,18 +626,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
     Value lb = loop.getLowerBound();
     Value ub = loop.getUpperBound();
     Value step = loop.getStep();
-    auto newLoopParams =
+    auto newLoopRange =
         emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
 
     rewriter.modifyOpInPlace(loop, [&]() {
-      loop.setLowerBound(getValueOrCreateConstantIntOp(
-          rewriter, loop.getLoc(), newLoopParams.lowerBound));
-      loop.setUpperBound(getValueOrCreateConstantIntOp(
-          rewriter, loop.getLoc(), newLoopParams.upperBound));
+      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+                                                       newLoopRange.offset));
+      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+                                                       newLoopRange.size));
       loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
-                                                 newLoopParams.step));
+                                                 newLoopRange.stride));
     });
-
     rewriter.setInsertionPointToStart(innermost.getBody());
     denormalizeInductionVariable(rewriter, loop.getLoc(),
                                  loop.getInductionVar(), lb, step);
@@ -780,9 +779,9 @@ void mlir::collapseParallelLoops(
     Value lb = loops.getLowerBound()[i];
     Value ub = loops.getUpperBound()[i];
     Value step = loops.getStep()[i];
-    auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
     normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
-        rewriter, loops.getLoc(), newLoopParams.upperBound));
+        rewriter, loops.getLoc(), newLoopRange.size));
 
     rewriter.setInsertionPointToStart(loops.getBody());
     denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!!

@MaheshRavishankar
Copy link
Contributor

Maybe you want to add a test for this?

@AviadCo
Copy link
Contributor Author

AviadCo commented Jun 14, 2024

@MaheshRavishankar those functions are covered in multiple tests (used by LoopCoalesceOp transform).
In addition, I have this PR which I am going to rebase and has many more test cases using those functions (also unify the struct with affine one), will update you once the next PR is rebased, thanks for the review!

@AviadCo AviadCo merged commit 2ecb1ab into llvm:main Jun 14, 2024
10 checks passed
@AviadCo AviadCo deleted the scf/refactor-LoopParams branch June 14, 2024 18:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants