-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][scf] add unroll-full option to test-loop-unrolling pass. #127158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][scf] add unroll-full option to test-loop-unrolling pass. #127158
Conversation
@llvm/pr-subscribers-mlir Author: lonely eagle (linuxlonelyeagle) Changes: Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the LLVM dialect, an unroll-full option was added. Full diff: https://github.com/llvm/llvm-project/pull/127158.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 02ffa0da7a8b8..c0c11c9e38994 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
+/// Unrolls this loop completely; fails if the trip count is not constant.
+LogicalResult loopUnrollFull(scf::ForOp forOp);
+
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
/// Returns failure if the loop cannot be unrolled either due to restrictions or
/// due to invalid unroll factors. In case of unroll factor of 1, the function
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fa82bcb816a2a..0ee325f6c0439 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -498,6 +498,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
return resultLoops;
}
+/// Fully unrolls `forOp` when its trip count is a known constant.
+///
+/// Returns failure if no constant trip count can be computed. A trip count of
+/// zero is reported as success with the loop left untouched, a trip count of
+/// one delegates to `promoteIfSingleIteration`, and anything larger is
+/// unrolled by a factor equal to the trip count.
+LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
+  IRRewriter rewriter(forOp.getContext());
+  const std::optional<uint64_t> constantTripCount = getConstantTripCount(forOp);
+  if (!constantTripCount)
+    return failure();
+
+  switch (*constantTripCount) {
+  case 0:
+    return success();
+  case 1:
+    return forOp.promoteIfSingleIteration(rewriter);
+  default:
+    return loopUnrollByFactor(forOp, *constantTripCount);
+  }
+}
+
/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index baf6b2970ac0e..75481863795ae 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
// CHECK-LABEL: scf_loop_unroll_single
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
// UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
}
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single(
+// UNROLL-FULL-SAME: %[[VAL_0:.*]]: index) -> index {
+func.func @scf_loop_unroll_full_single(%arg : index) -> index {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 4 : index
+ %4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index {
+ %3 = arith.addi %arg1, %arg : index
+ scf.yield %3 : index
+ }
+ return %4 : index
+ // UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 1 : index
+ // UNROLL-FULL: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : index
+ // UNROLL-FULL: %[[VAL_3:.*]] = arith.addi %[[VAL_2]], %[[VAL_0]] : index
+ // UNROLL-FULL: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : index
+ // UNROLL-FULL: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+ // UNROLL-FULL: return %[[VAL_5]] : index
+}
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops(
+// UNROLL-FULL-SAME: %[[VAL_0:.*]]: vector<4x4xindex>) -> index {
+func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 4 : index
+ %6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
+ %5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
+ %3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
+ %4 = arith.addi %3, %it1 : index
+ scf.yield %3 : index
+ }
+ scf.yield %5 : index
+ }
+ return %6 : index
+ // UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 0 : index
+ // UNROLL-FULL: %[[VAL_2:.*]] = arith.constant 1 : index
+ // UNROLL-FULL: %[[VAL_3:.*]] = arith.constant 4 : index
+ // UNROLL-FULL: %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, %[[VAL_5]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_7]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_4]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_11:.*]] = vector.extract %[[VAL_0]][1, %[[VAL_9]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_11]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_8]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_15:.*]] = vector.extract %[[VAL_0]][2, %[[VAL_13]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_15]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_19:.*]] = vector.extract %[[VAL_0]][3, %[[VAL_17]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_19]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: return %[[VAL_16]] : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index 8694a7f9bbd62..cb554c3dfb66c 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
unsigned loopDepthParam,
- bool annotateLoopParam) {
+ bool annotateLoopParam, bool unrollFullParam) {
unrollFactor = unrollFactorParam;
loopDepth = loopDepthParam;
annotateLoop = annotateLoopParam;
+ // Take the flag from its own parameter, not from the unroll factor.
+ unrollFull = unrollFullParam;
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
}
};
- for (auto loop : loops)
- (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+ for (auto loop : loops) {
+ // Both unrolling results are discarded explicitly.
+ if (unrollFull)
+ (void)loopUnrollFull(loop);
+ else
+ (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+ }
}
Option<uint64_t> unrollFactor{*this, "unroll-factor",
llvm::cl::desc("Loop unroll factor."),
@@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
llvm::cl::init(false)};
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
llvm::cl::init(0)};
+ Option<bool> unrollFull{*this, "unroll-full",
+ llvm::cl::desc("Full unroll loops."),
+ llvm::cl::init(false)};
};
} // namespace
|
@llvm/pr-subscribers-mlir-scf Author: lonely eagle (linuxlonelyeagle) Changes: Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the LLVM dialect, an unroll-full option was added. Full diff: https://github.com/llvm/llvm-project/pull/127158.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 02ffa0da7a8b8..c0c11c9e38994 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
+/// Unrolls this loop completely; fails if the trip count is not constant.
+LogicalResult loopUnrollFull(scf::ForOp forOp);
+
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
/// Returns failure if the loop cannot be unrolled either due to restrictions or
/// due to invalid unroll factors. In case of unroll factor of 1, the function
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index fa82bcb816a2a..0ee325f6c0439 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -498,6 +498,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
return resultLoops;
}
+/// Fully unrolls `forOp` when its trip count is a known constant.
+///
+/// Returns failure if no constant trip count can be computed. A trip count of
+/// zero is reported as success with the loop left untouched, a trip count of
+/// one delegates to `promoteIfSingleIteration`, and anything larger is
+/// unrolled by a factor equal to the trip count.
+LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
+  IRRewriter rewriter(forOp.getContext());
+  const std::optional<uint64_t> constantTripCount = getConstantTripCount(forOp);
+  if (!constantTripCount)
+    return failure();
+
+  switch (*constantTripCount) {
+  case 0:
+    return success();
+  case 1:
+    return forOp.promoteIfSingleIteration(rewriter);
+  default:
+    return loopUnrollByFactor(forOp, *constantTripCount);
+  }
+}
+
/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir
index baf6b2970ac0e..75481863795ae 100644
--- a/mlir/test/Transforms/scf-loop-unroll.mlir
+++ b/mlir/test/Transforms/scf-loop-unroll.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
+// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
// CHECK-LABEL: scf_loop_unroll_single
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
// UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
}
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single(
+// UNROLL-FULL-SAME: %[[VAL_0:.*]]: index) -> index {
+func.func @scf_loop_unroll_full_single(%arg : index) -> index {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 4 : index
+ %4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index {
+ %3 = arith.addi %arg1, %arg : index
+ scf.yield %3 : index
+ }
+ return %4 : index
+ // UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 1 : index
+ // UNROLL-FULL: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]] : index
+ // UNROLL-FULL: %[[VAL_3:.*]] = arith.addi %[[VAL_2]], %[[VAL_0]] : index
+ // UNROLL-FULL: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : index
+ // UNROLL-FULL: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_0]] : index
+ // UNROLL-FULL: return %[[VAL_5]] : index
+}
+
+// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops(
+// UNROLL-FULL-SAME: %[[VAL_0:.*]]: vector<4x4xindex>) -> index {
+func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 1 : index
+ %2 = arith.constant 4 : index
+ %6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
+ %5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
+ %3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
+ %4 = arith.addi %3, %it1 : index
+ scf.yield %3 : index
+ }
+ scf.yield %5 : index
+ }
+ return %6 : index
+ // UNROLL-FULL: %[[VAL_1:.*]] = arith.constant 0 : index
+ // UNROLL-FULL: %[[VAL_2:.*]] = arith.constant 1 : index
+ // UNROLL-FULL: %[[VAL_3:.*]] = arith.constant 4 : index
+ // UNROLL-FULL: %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, %[[VAL_5]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_7]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_4]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_11:.*]] = vector.extract %[[VAL_0]][1, %[[VAL_9]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_11]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_8]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_15:.*]] = vector.extract %[[VAL_0]][2, %[[VAL_13]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_15]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_12]]) -> (index) {
+ // UNROLL-FULL: %[[VAL_19:.*]] = vector.extract %[[VAL_0]][3, %[[VAL_17]]] : index from vector<4x4xindex>
+ // UNROLL-FULL: scf.yield %[[VAL_19]] : index
+ // UNROLL-FULL: }
+ // UNROLL-FULL: return %[[VAL_16]] : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
index 8694a7f9bbd62..cb554c3dfb66c 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp
@@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
unsigned loopDepthParam,
- bool annotateLoopParam) {
+ bool annotateLoopParam, bool unrollFullParam) {
unrollFactor = unrollFactorParam;
loopDepth = loopDepthParam;
annotateLoop = annotateLoopParam;
+ // Take the flag from its own parameter, not from the unroll factor.
+ unrollFull = unrollFullParam;
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
}
};
- for (auto loop : loops)
- (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+ for (auto loop : loops) {
+ // Both unrolling results are discarded explicitly.
+ if (unrollFull)
+ (void)loopUnrollFull(loop);
+ else
+ (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
+ }
}
Option<uint64_t> unrollFactor{*this, "unroll-factor",
llvm::cl::desc("Loop unroll factor."),
@@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
llvm::cl::init(false)};
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
llvm::cl::init(0)};
+ Option<bool> unrollFull{*this, "unroll-full",
+ llvm::cl::desc("Full unroll loops."),
+ llvm::cl::init(false)};
};
} // namespace
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after comments are addressed.
Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
Ping @ftynse Although you have approved it, I think you still need to do a final review, thank you. |
…127158) Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the LLVM dialect, an unroll-full option was added. --------- Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the LLVM dialect, an unroll-full option was added.