Skip to content

Commit b227c25

Browse files
[mlir][scf] add unroll-full option to test-loop-unrolling pass (llvm#127158)
Some loops cannot be unrolled by the affine-loop-unroll pass. After running the lower-affine pass, they can be unrolled in scf. To enable conversion of vector ops in scf to the llvm dialect, an unroll-full option was added. --------- Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent 251377c commit b227c25

File tree

4 files changed

+85
-3
lines changed

4 files changed

+85
-3
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
126126
scf::ForOp forOp, uint64_t unrollFactor,
127127
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
128128

129+
/// Unrolls this loop completely. Returns failure if the trip count of the loop
/// is not a known constant.
130+
LogicalResult loopUnrollFull(scf::ForOp forOp);
131+
129132
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
130133
/// Returns failure if the loop cannot be unrolled either due to restrictions or
131134
/// due to invalid unroll factors. In case of unroll factor of 1, the function

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,20 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
498498
return resultLoops;
499499
}
500500

501+
/// Unrolls this loop completely.
502+
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
503+
IRRewriter rewriter(forOp.getContext());
504+
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
505+
if (!mayBeConstantTripCount.has_value())
506+
return failure();
507+
uint64_t tripCount = *mayBeConstantTripCount;
508+
if (tripCount == 0)
509+
return success();
510+
if (tripCount == 1)
511+
return forOp.promoteIfSingleIteration(rewriter);
512+
return loopUnrollByFactor(forOp, tripCount);
513+
}
514+
501515
/// Check if bounds of all inner loops are defined outside of `forOp`
502516
/// and return false if not.
503517
static bool areInnerBoundsInvariant(scf::ForOp forOp) {

mlir/test/Transforms/scf-loop-unroll.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s
22
// RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1
3+
// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL
34

45
// CHECK-LABEL: scf_loop_unroll_single
56
func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 {
@@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () {
5657
// UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index
5758
// UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32
5859
}
60+
61+
// Full unrolling of a single loop whose bounds (0 to 4) and step (1) are
// constants: the expected output below is four chained arith.addi ops.
// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single
62+
// UNROLL-FULL-SAME: %[[ARG:.*]]: index)
63+
func.func @scf_loop_unroll_full_single(%arg : index) -> index {
64+
%0 = arith.constant 0 : index
65+
%1 = arith.constant 1 : index
66+
%2 = arith.constant 4 : index
67+
%4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index {
68+
%3 = arith.addi %arg1, %arg : index
69+
scf.yield %3 : index
70+
}
71+
return %4 : index
72+
// UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
73+
// UNROLL-FULL: %[[V0:.*]] = arith.addi %[[ARG]], %[[C1]] : index
74+
// UNROLL-FULL: %[[V1:.*]] = arith.addi %[[V0]], %[[ARG]] : index
75+
// UNROLL-FULL: %[[V2:.*]] = arith.addi %[[V1]], %[[ARG]] : index
76+
// UNROLL-FULL: %[[V3:.*]] = arith.addi %[[V2]], %[[ARG]] : index
77+
// UNROLL-FULL: return %[[V3]] : index
78+
}
79+
80+
// Full unrolling applied to a two-level loop nest: the expected output below
// keeps four copies of the inner scf.for, one per outer iteration, with the
// outer induction variable replaced by the constants 0 through 3.
// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops
81+
// UNROLL-FULL-SAME: %[[ARG:.*]]: vector<4x4xindex>)
82+
func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index {
83+
%0 = arith.constant 0 : index
84+
%1 = arith.constant 1 : index
85+
%2 = arith.constant 4 : index
86+
%6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index {
87+
%5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index {
88+
%3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex>
89+
%4 = arith.addi %3, %it1 : index
90+
// NOTE(review): %4 is computed but never used; the loop yields %3 and the
// expectations below match that. Confirm that yielding %3 is intended.
scf.yield %3 : index
91+
}
92+
scf.yield %5 : index
93+
}
94+
return %6 : index
95+
// UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index
96+
// UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index
97+
// UNROLL-FULL: %[[C4:.*]] = arith.constant 4 : index
98+
// UNROLL-FULL: %[[SUM0:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[C0]])
99+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][0, %[[IV]]] : index from vector<4x4xindex>
100+
// UNROLL-FULL: scf.yield %[[VAL]] : index
101+
// UNROLL-FULL: }
102+
// UNROLL-FULL: %[[SUM1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM0]])
103+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][1, %[[IV]]] : index from vector<4x4xindex>
104+
// UNROLL-FULL: scf.yield %[[VAL]] : index
105+
// UNROLL-FULL: }
106+
// UNROLL-FULL: %[[SUM2:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM1]])
107+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][2, %[[IV]]] : index from vector<4x4xindex>
108+
// UNROLL-FULL: scf.yield %[[VAL]] : index
109+
// UNROLL-FULL: }
110+
// UNROLL-FULL: %[[SUM3:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM2]])
111+
// UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][3, %[[IV]]] : index from vector<4x4xindex>
112+
// UNROLL-FULL: scf.yield %[[VAL]] : index
113+
// UNROLL-FULL: }
114+
// UNROLL-FULL: return %[[SUM3]] : index
115+
}

mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ struct TestLoopUnrollingPass
4242
TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
4343
explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
4444
unsigned loopDepthParam,
45-
bool annotateLoopParam) {
45+
bool annotateLoopParam, bool unrollFullParam) {
4646
unrollFactor = unrollFactorParam;
4747
loopDepth = loopDepthParam;
4848
annotateLoop = annotateLoopParam;
49+
unrollFull = unrollFactorParam;
4950
}
5051

5152
void getDependentDialects(DialectRegistry &registry) const override {
@@ -63,8 +64,12 @@ struct TestLoopUnrollingPass
6364
op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
6465
}
6566
};
66-
for (auto loop : loops)
67-
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
67+
for (auto loop : loops) {
68+
if (unrollFull)
69+
(void)loopUnrollFull(loop);
70+
else
71+
(void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
72+
}
6873
}
6974
Option<uint64_t> unrollFactor{*this, "unroll-factor",
7075
llvm::cl::desc("Loop unroll factor."),
@@ -77,6 +82,9 @@ struct TestLoopUnrollingPass
7782
llvm::cl::init(false)};
7883
Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
7984
llvm::cl::init(0)};
85+
Option<bool> unrollFull{*this, "unroll-full",
86+
llvm::cl::desc("Full unroll loops."),
87+
llvm::cl::init(false)};
8088
};
8189
} // namespace
8290

0 commit comments

Comments
 (0)