Skip to content

Commit f720150

Browse files
author
Jerry Wu
authored
[mlir] Add transformation to wrap scf::while in zero-trip-check (#81050)
Add `scf::wrapWhileLoopInZeroTripCheck` to wrap scf while loop in zero-trip-check.
1 parent 06c89bd commit f720150

File tree

7 files changed

+379
-0
lines changed

7 files changed

+379
-0
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace scf {
3030
class IfOp;
3131
class ForOp;
3232
class ParallelOp;
33+
class WhileOp;
3334

3435
/// Fuses all adjacent scf.parallel operations with identical bounds and step
3536
/// into one scf.parallel operations. Uses a naive aliasing and dependency
@@ -181,6 +182,46 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
181182
const PipeliningOption &options,
182183
bool *modifiedIR = nullptr);
183184

185+
/// Create zero-trip-check around a `while` op and return the new loop op in the
186+
/// check. The while loop is rotated to avoid evaluating the condition twice
187+
///
188+
/// By default the check won't be created for do-while loop as it is not
189+
/// required. `forceCreateCheck` can force the creation.
190+
///
191+
/// It turns:
192+
///
193+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
194+
/// %val = .., %arg0 : i64
195+
/// %cond = arith.cmpi .., %arg0 : i32
196+
/// scf.condition(%cond) %val : i64
197+
/// } do {
198+
/// ^bb0(%arg1: i64):
199+
/// %next = .., %arg1 : i32
200+
/// scf.yield %next : i32
201+
/// }
202+
///
203+
/// into:
204+
///
205+
/// %pre_val = .., %init : i64
206+
/// %pre_cond = arith.cmpi .., %init : i32
207+
/// scf.if %pre_cond -> i64 {
208+
/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
209+
/// %next = .., %arg1 : i32
210+
/// %val = .., %next : i64
211+
/// %cond = arith.cmpi .., %next : i32
212+
/// scf.condition(%cond) %val : i64
213+
/// } do {
214+
/// ^bb0(%arg2: i64):
215+
/// %scf.yield %arg2 : i32
216+
/// }
217+
/// scf.yield %res : i64
218+
/// } else {
219+
/// scf.yield %pre_val : i64
220+
/// }
221+
FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
222+
RewriterBase &rewriter,
223+
bool forceCreateCheck = false);
224+
184225
} // namespace scf
185226
} // namespace mlir
186227

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
1313
ParallelLoopTiling.cpp
1414
StructuralTypeConversions.cpp
1515
TileUsingInterface.cpp
16+
WrapInZeroTripCheck.cpp
1617

1718
ADDITIONAL_HEADER_DIRS
1819
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SCF/IR/SCF.h"
10+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
11+
#include "mlir/IR/IRMapping.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
14+
using namespace mlir;
15+
16+
/// Create zero-trip-check around a `while` op and return the new loop op in the
17+
/// check. The while loop is rotated to avoid evaluating the condition twice.
18+
///
19+
/// Given an example below:
20+
///
21+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
22+
/// %val = .., %arg0 : i64
23+
/// %cond = arith.cmpi .., %arg0 : i32
24+
/// scf.condition(%cond) %val : i64
25+
/// } do {
26+
/// ^bb0(%arg1: i64):
27+
/// %next = .., %arg1 : i32
28+
/// scf.yield %next : i32
29+
/// }
30+
///
31+
/// First clone before block to the front of the loop:
32+
///
33+
/// %pre_val = .., %init : i64
34+
/// %pre_cond = arith.cmpi .., %init : i32
35+
/// scf.while (%arg0 = %init) : (i32) -> i64 {
36+
/// %val = .., %arg0 : i64
37+
/// %cond = arith.cmpi .., %arg0 : i32
38+
/// scf.condition(%cond) %val : i64
39+
/// } do {
40+
/// ^bb0(%arg1: i64):
41+
/// %next = .., %arg1 : i32
42+
/// scf.yield %next : i32
43+
/// }
44+
///
45+
/// Create `if` op with the condition, rotate and move the loop into the else
46+
/// branch:
47+
///
48+
/// %pre_val = .., %init : i64
49+
/// %pre_cond = arith.cmpi .., %init : i32
50+
/// scf.if %pre_cond -> i64 {
51+
/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
52+
/// // Original after block
53+
/// %next = .., %arg1 : i32
54+
/// // Original before block
55+
/// %val = .., %next : i64
56+
/// %cond = arith.cmpi .., %next : i32
57+
/// scf.condition(%cond) %val : i64
58+
/// } do {
59+
/// ^bb0(%arg2: i64):
60+
/// %scf.yield %arg2 : i32
61+
/// }
62+
/// scf.yield %res : i64
63+
/// } else {
64+
/// scf.yield %pre_val : i64
65+
/// }
66+
FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
67+
scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) {
68+
// If the loop is in do-while form (after block only passes through values),
69+
// there is no need to create a zero-trip-check as before block is always run.
70+
if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) {
71+
return whileOp;
72+
}
73+
74+
OpBuilder::InsertionGuard insertion_guard(rewriter);
75+
76+
IRMapping mapper;
77+
Block *beforeBlock = whileOp.getBeforeBody();
78+
// Clone before block before the loop for zero-trip-check.
79+
for (auto [arg, init] :
80+
llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
81+
mapper.map(arg, init);
82+
}
83+
rewriter.setInsertionPoint(whileOp);
84+
for (auto &op : *beforeBlock) {
85+
if (isa<scf::ConditionOp>(op)) {
86+
break;
87+
}
88+
// Safe to clone everything as in a single block all defs have been cloned
89+
// and added to mapper in order.
90+
rewriter.insert(op.clone(mapper));
91+
}
92+
93+
scf::ConditionOp condOp = whileOp.getConditionOp();
94+
Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
95+
SmallVector<Value> clonedCondArgs = llvm::map_to_vector(
96+
condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
97+
98+
// Create rotated while loop.
99+
auto newLoopOp = rewriter.create<scf::WhileOp>(
100+
whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
101+
[&](OpBuilder &builder, Location loc, ValueRange args) {
102+
// Rotate and move the loop body into before block.
103+
auto newBlock = builder.getBlock();
104+
rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
105+
auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
106+
rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
107+
yieldOp.getResults());
108+
rewriter.eraseOp(yieldOp);
109+
},
110+
[&](OpBuilder &builder, Location loc, ValueRange args) {
111+
// Pass through values.
112+
builder.create<scf::YieldOp>(loc, args);
113+
});
114+
115+
// Create zero-trip-check and move the while loop in.
116+
auto ifOp = rewriter.create<scf::IfOp>(
117+
whileOp.getLoc(), clonedCondition,
118+
[&](OpBuilder &builder, Location loc) {
119+
// Then runs the while loop.
120+
rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
121+
builder.getInsertionPoint());
122+
builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
123+
},
124+
[&](OpBuilder &builder, Location loc) {
125+
// Else returns the results from precondition.
126+
builder.create<scf::YieldOp>(loc, clonedCondArgs);
127+
});
128+
129+
rewriter.replaceOp(whileOp, ifOp);
130+
131+
return newLoopOp;
132+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check='force-create-check=true' -split-input-file | FileCheck %s --check-prefix FORCE-CREATE-CHECK
3+
4+
func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
5+
%cst0 = arith.constant 0 : i32
6+
%cst5 = arith.constant 5 : i32
7+
%res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
8+
%cond = arith.cmpi slt, %iter, %bound : i32
9+
%inv = arith.addi %bound, %cst5 : i32
10+
scf.condition(%cond) %iter, %inv : i32, i32
11+
} do {
12+
^bb0(%arg1: i32, %arg2: i32):
13+
%next = arith.addi %arg1, %arg2 : i32
14+
scf.yield %next : i32
15+
}
16+
return %res#0 : i32
17+
}
18+
19+
// CHECK-LABEL: func.func @wrap_while_loop_in_zero_trip_check(
20+
// CHECK-SAME: %[[BOUND:.*]]: i32) -> i32 {
21+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
22+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
23+
// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[BOUND]] : i32
24+
// CHECK-DAG: %[[PRE_INV:.*]] = arith.addi %[[BOUND]], %[[C5]] : i32
25+
// CHECK: %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
26+
// CHECK: %[[WHILE:.*]]:2 = scf.while (
27+
// CHECK-SAME: %[[ARG1:.*]] = %[[C0]], %[[ARG2:.*]] = %[[PRE_INV]]
28+
// CHECK-SAME: ) : (i32, i32) -> (i32, i32) {
29+
// CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
30+
// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[BOUND]] : i32
31+
// CHECK: %[[INV:.*]] = arith.addi %[[BOUND]], %[[C5]] : i32
32+
// CHECK: scf.condition(%[[COND]]) %[[NEXT]], %[[INV]] : i32, i32
33+
// CHECK: } do {
34+
// CHECK: ^bb0(%[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
35+
// CHECK: scf.yield %[[ARG3]], %[[ARG4]] : i32, i32
36+
// CHECK: }
37+
// CHECK: scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, i32
38+
// CHECK: } else {
39+
// CHECK: scf.yield %[[C0]], %[[PRE_INV]] : i32, i32
40+
// CHECK: }
41+
// CHECK: return %[[IF]]#0 : i32
42+
43+
// -----
44+
45+
func.func @wrap_while_loop_with_minimal_before_block(%bound : i32) -> i32 {
46+
%cst0 = arith.constant 0 : i32
47+
%true = arith.constant true
48+
%cst5 = arith.constant 5 : i32
49+
%res = scf.while (%iter = %cst0, %arg0 = %true) : (i32, i1) -> i32 {
50+
scf.condition(%arg0) %iter : i32
51+
} do {
52+
^bb0(%arg1: i32):
53+
%next = arith.addi %arg1, %cst5 : i32
54+
%cond = arith.cmpi slt, %next, %bound : i32
55+
scf.yield %next, %cond : i32, i1
56+
}
57+
return %res : i32
58+
}
59+
60+
// CHECK-LABEL: func.func @wrap_while_loop_with_minimal_before_block(
61+
// CHECK-SAME: %[[BOUND:.*]]: i32) -> i32 {
62+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
63+
// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
64+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
65+
// CHECK: %[[IF:.*]] = scf.if %[[TRUE]] -> (i32) {
66+
// CHECK: %[[WHILE:.*]] = scf.while (%[[ARG1:.*]] = %[[C0]]) : (i32) -> i32 {
67+
// CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[C5]] : i32
68+
// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[BOUND]] : i32
69+
// CHECK: scf.condition(%[[COND]]) %[[NEXT]] : i32
70+
// CHECK: } do {
71+
// CHECK: ^bb0(%[[ARG2:.*]]: i32):
72+
// CHECK: scf.yield %[[ARG2]] : i32
73+
// CHECK: }
74+
// CHECK: scf.yield %[[WHILE]] : i32
75+
// CHECK: } else {
76+
// CHECK: scf.yield %[[C0]] : i32
77+
// CHECK: }
78+
// CHECK: return %[[IF]] : i32
79+
80+
// -----
81+
82+
func.func @wrap_do_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
83+
%cst0 = arith.constant 0 : i32
84+
%cst5 = arith.constant 5 : i32
85+
%res = scf.while (%iter = %cst0) : (i32) -> i32 {
86+
%next = arith.addi %iter, %cst5 : i32
87+
%cond = arith.cmpi slt, %next, %bound : i32
88+
scf.condition(%cond) %next : i32
89+
} do {
90+
^bb0(%arg1: i32):
91+
scf.yield %arg1 : i32
92+
}
93+
return %res : i32
94+
}
95+
96+
// CHECK-LABEL: func.func @wrap_do_while_loop_in_zero_trip_check(
97+
// CHECK-SAME: %[[BOUND:.*]]: i32) -> i32 {
98+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
99+
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
100+
// CHECK-NOT: scf.if
101+
// CHECK: %[[WHILE:.*]] = scf.while (%[[ARG1:.*]] = %[[C0]]) : (i32) -> i32 {
102+
// CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[C5]] : i32
103+
// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[BOUND]] : i32
104+
// CHECK: scf.condition(%[[COND]]) %[[NEXT]] : i32
105+
// CHECK: } do {
106+
// CHECK: ^bb0(%[[ARG2:.*]]: i32):
107+
// CHECK: scf.yield %[[ARG2]] : i32
108+
// CHECK: }
109+
// CHECK: return %[[WHILE]] : i32
110+
111+
// FORCE-CREATE-CHECK-LABEL: func.func @wrap_do_while_loop_in_zero_trip_check(
112+
// FORCE-CREATE-CHECK-SAME: %[[BOUND:.*]]: i32) -> i32 {
113+
// FORCE-CREATE-CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
114+
// FORCE-CREATE-CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
115+
// FORCE-CREATE-CHECK: %[[PRE_NEXT:.*]] = arith.addi %[[C0]], %[[C5]] : i32
116+
// FORCE-CREATE-CHECK: %[[PRE_COND:.*]] = arith.cmpi slt, %[[PRE_NEXT]], %[[BOUND]] : i32
117+
// FORCE-CREATE-CHECK: %[[IF:.*]] = scf.if %[[PRE_COND]] -> (i32) {
118+
// FORCE-CREATE-CHECK: %[[WHILE:.*]] = scf.while (%[[ARG1:.*]] = %[[PRE_NEXT]]) : (i32) -> i32 {
119+
// FORCE-CREATE-CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[C5]] : i32
120+
// FORCE-CREATE-CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[BOUND]] : i32
121+
// FORCE-CREATE-CHECK: scf.condition(%[[COND]]) %[[NEXT]] : i32
122+
// FORCE-CREATE-CHECK: } do {
123+
// FORCE-CREATE-CHECK: ^bb0(%[[ARG2:.*]]: i32):
124+
// FORCE-CREATE-CHECK: scf.yield %[[ARG2]] : i32
125+
// FORCE-CREATE-CHECK: }
126+
// FORCE-CREATE-CHECK: scf.yield %[[WHILE]] : i32
127+
// FORCE-CREATE-CHECK: } else {
128+
// FORCE-CREATE-CHECK: scf.yield %[[PRE_NEXT]] : i32
129+
// FORCE-CREATE-CHECK: }
130+
// FORCE-CREATE-CHECK: return %[[IF]] : i32

mlir/test/lib/Dialect/SCF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
33
TestLoopParametricTiling.cpp
44
TestLoopUnrolling.cpp
55
TestSCFUtils.cpp
6+
TestSCFWrapInZeroTripCheck.cpp
67
TestWhileOpBuilder.cpp
78

89
EXCLUDE_FROM_LIBMLIR

0 commit comments

Comments
 (0)