Skip to content

Commit 644b55d

Browse files
committed
[MLIR][SCF] Add for-to-while loop transformation pass
This pass transforms SCF.ForOp operations to SCF.WhileOp. The For loop condition is placed in the 'before' region of the while operation, and the induction variable incrementation + the loop body in the 'after' region. The loop carried values of the while op are the induction variable (IV) of the for-loop + any iter_args specified for the for-loop. Any 'yield' ops in the for-loop are rewritten to additionally yield the (incremented) induction variable. This transformation is useful for passes where we want to consider structured control flow solely on the basis of a loop body and the computation of a loop condition. As an example, when doing high-level synthesis in CIRCT, the incrementation of an IV in a for-loop is "just another part" of a circuit datapath, and what we really care about is the distinction between our datapath and our control logic (the condition variable). Differential Revision: https://reviews.llvm.org/D108454
1 parent 09100c7 commit 644b55d

File tree

5 files changed

+297
-0
lines changed

5 files changed

+297
-0
lines changed

mlir/include/mlir/Dialect/SCF/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
5252
/// loop range.
5353
std::unique_ptr<Pass> createForLoopRangeFoldingPass();
5454

55+
/// Creates a pass which lowers for loops into while loops.
56+
std::unique_ptr<Pass> createForToWhileLoopPass();
57+
5558
//===----------------------------------------------------------------------===//
5659
// Registration
5760
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SCF/Passes.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,39 @@ def SCFForLoopRangeFolding
7878
let constructor = "mlir::createForLoopRangeFoldingPass()";
7979
}
8080

81+
def SCFForToWhileLoop
82+
: FunctionPass<"scf-for-to-while"> {
83+
let summary = "Convert SCF for loops to SCF while loops";
84+
let constructor = "mlir::createForToWhileLoopPass()";
85+
let description = [{
86+
This pass transforms SCF.ForOp operations to SCF.WhileOp. The For loop
87+
condition is placed in the 'before' region of the while operation, and the
88+
induction variable incrementation and loop body in the 'after' region.
89+
The loop carried values of the while op are the induction variable (IV) of
90+
the for-loop + any iter_args specified for the for-loop.
91+
Any 'yield' ops in the for-loop are rewritten to additionally yield the
92+
(incremented) induction variable.
93+
94+
```mlir
95+
# Before:
96+
scf.for %i = %c0 to %arg1 step %c1 {
97+
%0 = addi %arg2, %arg2 : i32
98+
memref.store %0, %arg0[%i] : memref<?xi32>
99+
}
100+
101+
# After:
102+
%0 = scf.while (%i = %c0) : (index) -> index {
103+
%1 = cmpi slt, %i, %arg1 : index
104+
scf.condition(%1) %i : index
105+
} do {
106+
^bb0(%i: index): // no predecessors
107+
%1 = addi %i, %c1 : index
108+
%2 = addi %arg2, %arg2 : i32
109+
memref.store %2, %arg0[%i] : memref<?xi32>
110+
scf.yield %1 : index
111+
}
112+
```
113+
}];
114+
}
115+
81116
#endif // MLIR_DIALECT_SCF_PASSES

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRSCFTransforms
22
Bufferize.cpp
3+
ForToWhile.cpp
34
LoopCanonicalization.cpp
45
LoopPipelining.cpp
56
LoopRangeFolding.cpp
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ForOp's into SCF.WhileOp's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "PassDetail.h"
14+
#include "mlir/Dialect/SCF/Passes.h"
15+
#include "mlir/Dialect/SCF/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms.h"
17+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
21+
using namespace llvm;
22+
using namespace mlir;
23+
using scf::ForOp;
24+
using scf::WhileOp;
25+
26+
namespace {
27+
28+
/// Rewrites an `scf.for` operation into an equivalent `scf.while` operation.
///
/// The loop-carried values of the generated while op are the induction
/// variable followed by the for-loop's iter_args. The 'before' region
/// evaluates the signed comparison `iv < upperBound` and forwards every
/// carried value to the 'after' region. The 'after' region first computes
/// `iv + step` and then runs the original loop body; the body's terminating
/// yield is extended so that it yields the incremented induction variable in
/// front of the original operands.
struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Generate type signature for the loop-carried values. The induction
    // variable is placed first, followed by the forOp.iterArgs.
    SmallVector<Type, 8> lcvTypes;
    lcvTypes.push_back(forOp.getInductionVar().getType());
    llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
                    [](Value v) { return v.getType(); });

    // Build scf.WhileOp. The induction variable starts at the lower bound;
    // the remaining operands are the for-loop's initial iter_args. Attributes
    // attached to the for op are forwarded to the while op.
    SmallVector<Value> initArgs;
    initArgs.push_back(forOp.lowerBound());
    llvm::append_range(initArgs, forOp.initArgs());
    auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
                                            forOp->getAttrs());

    // 'before' region contains the loop condition and forwarding of iteration
    // arguments to the 'after' region.
    auto *beforeBlock = rewriter.createBlock(
        &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
    rewriter.setInsertionPointToStart(&whileOp.before().front());
    auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
                                         beforeBlock->getArgument(0),
                                         forOp.upperBound());
    rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
                                      beforeBlock->getArguments());

    // The 'after' region receives the same loop-carried values. It holds the
    // induction variable increment followed by the operations of the original
    // for-loop body.
    auto *afterBlock = rewriter.createBlock(
        &whileOp.after(), whileOp.after().begin(), lcvTypes, {});

    // Add induction variable incrementation. It is created before the body is
    // moved in, so it precedes every body operation.
    rewriter.setInsertionPointToEnd(afterBlock);
    auto ivIncOp = rewriter.create<AddIOp>(
        whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());

    // Rewrite uses of the for-loop block arguments to the new while-loop
    // "after" arguments. Both blocks order their arguments as
    // (iv, iter_args...), so the indices line up one-to-one.
    for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
      barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));

    // Inline for-loop body operations (including the terminating yield) into
    // the 'after' region.
    for (auto &bodyOp : llvm::make_early_inc_range(*forOp.getBody()))
      bodyOp.moveBefore(afterBlock, afterBlock->end());

    // Prepend the incremented IV to the operands of the yield that now
    // terminates the 'after' block. The in-place update is routed through the
    // rewriter so that the driver is notified of the modification.
    auto yieldOp = cast<scf::YieldOp>(afterBlock->getTerminator());
    SmallVector<Value> yieldOperands;
    yieldOperands.reserve(yieldOp->getNumOperands() + 1);
    yieldOperands.push_back(ivIncOp.getResult());
    llvm::append_range(yieldOperands, yieldOp->getOperands());
    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });

    // The while op returns one extra leading value (the induction variable
    // escapes the loop through being carried in the set of iterargs). Drop it
    // so the remaining results map one-to-one onto the forOp results, then
    // replace and erase the forOp through the rewriter.
    rewriter.replaceOp(forOp, whileOp->getResults().drop_front());
    return success();
  }
};
96+
97+
struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
98+
void runOnFunction() override {
99+
FuncOp funcOp = getFunction();
100+
MLIRContext *ctx = funcOp.getContext();
101+
RewritePatternSet patterns(ctx);
102+
patterns.add<ForLoopLoweringPattern>(ctx);
103+
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
104+
}
105+
};
106+
} // namespace
107+
108+
std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
109+
return std::make_unique<ForToWhileLoop>();
110+
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s
2+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
3+
4+
// CHECK-LABEL: builtin.func @single_loop(
5+
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>,
6+
// CHECK-SAME: %[[VAL_1:.*]]: index,
7+
// CHECK-SAME: %[[VAL_2:.*]]: i32) {
8+
// CHECK: %[[VAL_3:.*]] = constant 0 : index
9+
// CHECK: %[[VAL_4:.*]] = constant 1 : index
10+
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
11+
// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
12+
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
13+
// CHECK: } do {
14+
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
15+
// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
16+
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
17+
// CHECK: memref.store %[[VAL_10]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
18+
// CHECK: scf.yield %[[VAL_9]] : index
19+
// CHECK: }
20+
// CHECK: return
21+
// CHECK: }
22+
func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
23+
%c0 = constant 0 : index
24+
%c1 = constant 1 : index
25+
scf.for %i = %c0 to %arg1 step %c1 {
26+
%0 = addi %arg2, %arg2 : i32
27+
memref.store %0, %arg0[%i] : memref<?xi32>
28+
}
29+
return
30+
}
31+
32+
// -----
33+
34+
// CHECK-LABEL: builtin.func @nested_loop(
35+
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>,
36+
// CHECK-SAME: %[[VAL_1:.*]]: index,
37+
// CHECK-SAME: %[[VAL_2:.*]]: i32) {
38+
// CHECK: %[[VAL_3:.*]] = constant 0 : index
39+
// CHECK: %[[VAL_4:.*]] = constant 1 : index
40+
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
41+
// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
42+
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
43+
// CHECK: } do {
44+
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
45+
// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
46+
// CHECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_3]]) : (index) -> index {
47+
// CHECK: %[[VAL_12:.*]] = cmpi slt, %[[VAL_11]], %[[VAL_1]] : index
48+
// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
49+
// CHECK: } do {
50+
// CHECK: ^bb0(%[[VAL_13:.*]]: index):
51+
// CHECK: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_4]] : index
52+
// CHECK: %[[VAL_15:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
53+
// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
54+
// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<?xi32>
55+
// CHECK: scf.yield %[[VAL_14]] : index
56+
// CHECK: }
57+
// CHECK: scf.yield %[[VAL_9]] : index
58+
// CHECK: }
59+
// CHECK: return
60+
// CHECK: }
61+
func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
62+
%c0 = constant 0 : index
63+
%c1 = constant 1 : index
64+
scf.for %i = %c0 to %arg1 step %c1 {
65+
scf.for %j = %c0 to %arg1 step %c1 {
66+
%0 = addi %arg2, %arg2 : i32
67+
memref.store %0, %arg0[%i] : memref<?xi32>
68+
memref.store %0, %arg0[%j] : memref<?xi32>
69+
}
70+
}
71+
return
72+
}
73+
74+
// -----
75+
76+
// CHECK-LABEL: builtin.func @for_iter_args(
77+
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
78+
// CHECK-SAME: %[[VAL_2:.*]]: index) -> f32 {
79+
// CHECK: %[[VAL_3:.*]] = constant 0.000000e+00 : f32
80+
// CHECK: %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) {
81+
// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index
82+
// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32
83+
// CHECK: } do {
84+
// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
85+
// CHECK: %[[VAL_12:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index
86+
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32
87+
// CHECK: scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_13]] : index, f32, f32
88+
// CHECK: }
89+
// CHECK: return %[[VAL_14:.*]]#2 : f32
90+
// CHECK: }
91+
func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
92+
%s0 = constant 0.0 : f32
93+
%result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) {
94+
%sn = addf %iarg0, %iarg1 : f32
95+
scf.yield %sn, %sn : f32, f32
96+
}
97+
return %result#1 : f32
98+
}
99+
100+
// -----
101+
102+
// CHECK-LABEL: builtin.func @exec_region_multiple_yields(
103+
// CHECK-SAME: %[[VAL_0:.*]]: i32,
104+
// CHECK-SAME: %[[VAL_1:.*]]: index,
105+
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i32 {
106+
// CHECK: %[[VAL_3:.*]] = constant 0 : index
107+
// CHECK: %[[VAL_4:.*]] = constant 1 : index
108+
// CHECK: %[[VAL_5:.*]]:2 = scf.while (%[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_0]]) : (index, i32) -> (index, i32) {
109+
// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
110+
// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_6]], %[[VAL_7]] : index, i32
111+
// CHECK: } do {
112+
// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32):
113+
// CHECK: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_4]] : index
114+
// CHECK: %[[VAL_12:.*]] = scf.execute_region -> i32 {
115+
// CHECK: %[[VAL_13:.*]] = cmpi slt, %[[VAL_9]], %[[VAL_4]] : index
116+
// CHECK: cond_br %[[VAL_13]], ^bb1, ^bb2
117+
// CHECK: ^bb1:
118+
// CHECK: %[[VAL_14:.*]] = subi %[[VAL_10]], %[[VAL_0]] : i32
119+
// CHECK: scf.yield %[[VAL_14]] : i32
120+
// CHECK: ^bb2:
121+
// CHECK: %[[VAL_15:.*]] = muli %[[VAL_10]], %[[VAL_2]] : i32
122+
// CHECK: scf.yield %[[VAL_15]] : i32
123+
// CHECK: }
124+
// CHECK: scf.yield %[[VAL_11]], %[[VAL_16:.*]] : index, i32
125+
// CHECK: }
126+
// CHECK: return %[[VAL_17:.*]]#1 : i32
127+
// CHECK: }
128+
func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
129+
%c1_i32 = constant 1 : i32
130+
%c2_i32 = constant 2 : i32
131+
%c0 = constant 0 : index
132+
%c1 = constant 1 : index
133+
%c5 = constant 5 : index
134+
%0 = scf.for %i = %c0 to %arg1 step %c1 iter_args(%iarg0 = %arg0) -> i32 {
135+
%2 = scf.execute_region -> i32 {
136+
%1 = cmpi slt, %i, %c1 : index
137+
cond_br %1, ^bb1, ^bb2
138+
^bb1:
139+
%2 = subi %iarg0, %arg0 : i32
140+
scf.yield %2 : i32
141+
^bb2:
142+
%3 = muli %iarg0, %arg2 : i32
143+
scf.yield %3 : i32
144+
}
145+
scf.yield %2 : i32
146+
}
147+
return %0 : i32
148+
}

0 commit comments

Comments
 (0)