[mlir][scf] add unroll-full option to test-loop-unrolling pass. #127158

Merged 4 commits on Feb 18, 2025

linuxlonelyeagle
Member

Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the LLVM dialect, an unroll-full option was added to the test-loop-unrolling pass.
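
For readers unfamiliar with the transformation, here is a rough standalone sketch of what "unrolling a loop completely" means for a loop that carries one value across iterations (similar in spirit to an scf.for with a single iter_arg). It is plain C++, not MLIR code; `rolledSum` and `unrolledSum4` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of `for (iv = lb; iv < ub; iv += step)` carrying one
// accumulator through its iterations, like a single scf.for iter_arg.
std::int64_t rolledSum(std::int64_t init, std::int64_t addend,
                       std::int64_t lb, std::int64_t ub, std::int64_t step) {
  assert(step > 0 && "this sketch only models positive steps");
  std::int64_t acc = init;
  for (std::int64_t iv = lb; iv < ub; iv += step)
    acc += addend;
  return acc;
}

// The same computation for lb = 0, ub = 4, step = 1, unrolled by hand:
// four copies of the body chained through the carried value, no loop left.
std::int64_t unrolledSum4(std::int64_t init, std::int64_t addend) {
  std::int64_t acc0 = init + addend;
  std::int64_t acc1 = acc0 + addend;
  std::int64_t acc2 = acc1 + addend;
  std::int64_t acc3 = acc2 + addend;
  return acc3;
}
```

The unrolled form is only possible because the trip count (4 here) is known up front; with non-constant bounds there is no fixed number of body copies to emit.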

@llvmbot
Member

llvmbot commented Feb 14, 2025

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

Changes

Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the LLVM dialect, an unroll-full option was added to the test-loop-unrolling pass.


Full diff: https://github.com/llvm/llvm-project/pull/127158.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+3)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+15)
  • (modified) mlir/test/Transforms/scf-loop-unroll.mlir (+57)
  • (modified) mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp (+11-3)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 02ffa0da7a8b8..c0c11c9e38994 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls this loop completely.
+LogicalResult loopUnrollFull(scf::ForOp forOp);
+
 /// Unrolls and jams this `scf.for` operation by the specified unroll factor.
 /// Returns failure if the loop cannot be unrolled either due to restrictions or
 /// due to invalid unroll factors. In case of unroll factor of 1, the function
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fa82bcb816a2a..0ee325f6c0439 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -498,6 +498,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
   return resultLoops;
 }
 
+/// Unrolls this loop completely.
+LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
+  IRRewriter rewriter(forOp.getContext());
+  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  if (mayBeConstantTripCount.has_value()) {
+    uint64_t tripCount = *mayBeConstantTripCount;
+    if (tripCount == 0)
+      return success();
+    if (tripCount == 1)
+      return forOp.promoteIfSingleIteration(rewriter);
+    return loopUnrollByFactor(forOp, tripCount);
+  }
+  return failure();
+}
+
 /// Check if bounds of all inner loops are defined outside of `forOp`
 /// and return false if not.
 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index baf6b2970ac0e..75481863795ae 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
 // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
 
 // CHECK-LABEL: scf_loop_unroll_single
 func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
   // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
   // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
 }
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single(
+// UNROLL-FULL-SAME:    %[[VAL_0:.*]]: index) -> index {
+func.func @scf_loop_unroll_full_single(%arg : index) -> index {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 4 : index
+  %4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index {
+    %3 = arith.addi %arg1, %arg : index
+    scf.yield %3 : index
+  }
+  return %4 : index
+  // UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 1 : index
+  // UNROLL-FULL: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : index
+  // UNROLL-FULL: %[[VAL_3:.*]] = arith.addi %[[VAL_2]], %[[VAL_0]] : index
+  // UNROLL-FULL: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : index
+  // UNROLL-FULL: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+  // UNROLL-FULL: return %[[VAL_5]] : index
+}
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops(
+// UNROLL-FULL-SAME:    %[[VAL_0:.*]]: vector<4x4xindex>) -> index {
+func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 4 : index
+  %6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
+    %5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
+      %3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
+      %4 = arith.addi %3, %it1 : index
+      scf.yield %3 : index
+    }
+    scf.yield %5 : index
+  }
+  return %6 : index
+  // UNROLL-FULL:   %[[VAL_1:.*]] = arith.constant 0 : index
+  // UNROLL-FULL:   %[[VAL_2:.*]] = arith.constant 1 : index
+  // UNROLL-FULL:   %[[VAL_3:.*]] = arith.constant 4 : index
+  // UNROLL-FULL:   %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, %[[VAL_5]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_7]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL:   %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_4]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_11:.*]] = vector.extract %[[VAL_0]][1, %[[VAL_9]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_11]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL:   %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_8]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_15:.*]] = vector.extract %[[VAL_0]][2, %[[VAL_13]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_15]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL:   %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (index) {
+  // UNROLL-FULL:   %[[VAL_19:.*]] = vector.extract %[[VAL_0]][3, %[[VAL_17]]] : index from vector<4x4xindex>
+  // UNROLL-FULL:   scf.yield %[[VAL_19]] : index
+  // UNROLL-FULL: }
+  // UNROLL-FULL: return %[[VAL_16]] : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index 8694a7f9bbd62..cb554c3dfb66c 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
   TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
   explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
                                  unsigned loopDepthParam,
-                                 bool annotateLoopParam) {
+                                 bool annotateLoopParam, bool unrollFullParam) {
     unrollFactor = unrollFactorParam;
     loopDepth = loopDepthParam;
     annotateLoop = annotateLoopParam;
+    unrollFull = unrollFullParam;
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
         op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
       }
     };
-    for (auto loop : loops)
-      (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+    for (auto loop : loops) {
+      if (unrollFull)
+        (void)loopUnrollFull(loop);
+      else
+        (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+    }
   }
   Option<uint64_t> unrollFactor{*this, "unroll-factor",
                                 llvm::cl::desc("Loop unroll factor."),
@@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
                                 llvm::cl::init(false)};
   Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
                              llvm::cl::init(0)};
+  Option<bool> unrollFull{*this, "unroll-full",
+                          llvm::cl::desc("Full unroll loops."),
+                          llvm::cl::init(false)};
 };
 } // namespace
 

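The control flow of `loopUnrollFull` in the diff above can be summarized with a small standalone model. This is a sketch under stated assumptions, not the MLIR implementation: `LoopBounds`, `constantTripCount`, `FullUnrollAction`, and `classifyFullUnroll` are hypothetical names, the trip count formula assumes a positive step and ignores overflow, and the real `getConstantTripCount` works on `scf::ForOp` operands rather than on plain integers.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for loop bounds; an empty optional means the value
// is not a compile-time constant.
struct LoopBounds {
  std::optional<std::int64_t> lb, ub, step;
};

// Assumed trip count rule: ceil((ub - lb) / step) for a positive step,
// 0 when the range is empty, nullopt when anything is unknown.
std::optional<std::uint64_t> constantTripCount(const LoopBounds &b) {
  if (!b.lb || !b.ub || !b.step || *b.step <= 0)
    return std::nullopt;
  if (*b.ub <= *b.lb)
    return 0;
  return static_cast<std::uint64_t>((*b.ub - *b.lb + *b.step - 1) / *b.step);
}

enum class FullUnrollAction {
  NothingToDo,            // trip count 0: report success, leave the loop
  PromoteSingleIteration, // trip count 1: inline the body once
  UnrollByTripCount,      // otherwise: unroll by a factor equal to the trip count
  Fail                    // trip count unknown: cannot fully unroll
};

// Mirrors the branch structure of loopUnrollFull in the diff.
FullUnrollAction classifyFullUnroll(std::optional<std::uint64_t> tripCount) {
  if (!tripCount)
    return FullUnrollAction::Fail;
  if (*tripCount == 0)
    return FullUnrollAction::NothingToDo;
  if (*tripCount == 1)
    return FullUnrollAction::PromoteSingleIteration;
  return FullUnrollAction::UnrollByTripCount;
}
```

For the loops in the new test cases (bounds 0 to 4, step 1), this model yields a trip count of 4 and selects the unroll-by-trip-count path; a loop whose upper bound is a function argument would fall into the failure case.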
Member

@ftynse ftynse left a comment

LGTM after comments are addressed.

linuxlonelyeagle and others added 2 commits February 17, 2025 17:13
Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
@linuxlonelyeagle
Member Author

Ping @ftynse: although you have already approved this, I think it still needs a final review from you. Thank you.

@linuxlonelyeagle linuxlonelyeagle merged commit b227c25 into llvm:main Feb 18, 2025
8 checks passed
wldfngrs pushed a commit to wldfngrs/llvm-project that referenced this pull request Feb 19, 2025
…127158)

Some loops cannot be unrolled by the affine-loop-unroll pass. After running
the lower-affine pass, they can be unrolled in scf. To enable conversion of
vector ops in scf to the LLVM dialect, an unroll-full option was added.

---------

Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>