Skip to content

Commit e029a5b

Browse files
committed
[mlir][scf]: Add value bound for the computed upper bound of for loop
Add additional bound for the induction variable of the `scf.for` such that: `%iv <= %lower_bound + (%trip_count - 1) * step`
1 parent 70906f0 commit e029a5b

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ namespace {
2020
struct ForOpInterface
2121
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
2222

23+
static AffineExpr getTripCountExpr(scf::ForOp forOp,
24+
ValueBoundsConstraintSet &cstr) {
25+
AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
26+
AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
27+
AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
28+
AffineExpr tripCountExpr =
29+
AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step
30+
return tripCountExpr;
31+
}
32+
2333
/// Populate bounds of values/dimensions for iter_args/OpResults. If the
2434
/// value/dimension size does not change in an iteration, we can deduce that
2535
/// it the same as the initial value/dimension.
@@ -77,11 +87,7 @@ struct ForOpInterface
7787
// `value` is result of `forOp`, we can prove that:
7888
// %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
7989
// Where trip_count is (ub - lb) / step.
80-
AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
81-
AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
82-
AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
83-
AffineExpr tripCountExpr =
84-
AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step
90+
AffineExpr tripCountExpr = getTripCountExpr(forOp, cstr);
8591
AffineExpr oneIterAdvanceExpr =
8692
cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
8793
cstr.bound(value) ==
@@ -93,9 +99,18 @@ struct ForOpInterface
9399
auto forOp = cast<ForOp>(op);
94100

95101
if (value == forOp.getInductionVar()) {
96-
// TODO: Take into account step size.
97102
cstr.bound(value) >= forOp.getLowerBound();
98103
cstr.bound(value) < forOp.getUpperBound();
104+
// iv <= lb + ((ub-lb)/step - 1) * step
105+
// This bound does not replace the `iv < ub` constraint mentioned above,
106+
// since constraints involving the multiplication of two two constraint
107+
// set dimensions are not supported.
108+
AffineExpr tripCountMinusOne =
109+
getTripCountExpr(forOp, cstr) - cstr.getExpr(1);
110+
AffineExpr computedUpperBound =
111+
cstr.getExpr(forOp.getLowerBound()) +
112+
AffineExpr(tripCountMinusOne * cstr.getExpr(forOp.getStep()));
113+
cstr.bound(value) <= computedUpperBound;
99114
return;
100115
}
101116

mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,47 @@ func.func @compare_scf_for(%a: index, %b: index, %c: index) {
270270

271271
// -----
272272

273+
func.func @scf_for_induction_var_upper_bound() {
274+
%c0 = arith.constant 0 : index
275+
%c1 = arith.constant 1 : index
276+
%c2 = arith.constant 2 : index
277+
%c3 = arith.constant 3 : index
278+
%c4 = arith.constant 4 : index
279+
%c5 = arith.constant 5 : index
280+
%c8 = arith.constant 8 : index
281+
%c10 = arith.constant 10 : index
282+
scf.for %iv = %c0 to %c10 step %c4 {
283+
// expected-remark @below{{true}}
284+
"test.compare"(%iv, %c8) {cmp = "LE"} : (index, index) -> ()
285+
}
286+
scf.for %iv = %c2 to %c8 step %c3 {
287+
// expected-remark @below{{true}}
288+
"test.compare"(%iv, %c5) {cmp = "LE"} : (index, index) -> ()
289+
}
290+
return
291+
}
292+
293+
// -----
294+
295+
#map_ceildiv_dynamic_divisor = affine_map<(i)[s] -> (i ceildiv s)>
296+
func.func @scf_for_induction_var_computed_upper_bound(%upperBound: index, %step: index) {
297+
%c0 = arith.constant 0 : index
298+
%c1 = arith.constant 1 : index
299+
%tripCount = affine.apply #map_ceildiv_dynamic_divisor (%upperBound)[%step]
300+
%tripCountMinusOne = arith.subi %tripCount, %c1 : index
301+
%computedUpperBound = arith.muli %tripCountMinusOne, %step : index
302+
scf.for %iv = %c0 to %upperBound step %step {
303+
// TODO: Value bounds analysis will fail to compute upper bound
304+
// because multiplication/division of unknown block arguments is
305+
// not supported.
306+
// expected-error @below{{unknown}}
307+
"test.compare"(%iv, %computedUpperBound) {cmp = "LE"} : (index, index) -> ()
308+
}
309+
return
310+
}
311+
312+
// -----
313+
273314
func.func @scf_for_result_infer() {
274315
%c0 = arith.constant 0 : index
275316
%c1 = arith.constant 1 : index

0 commit comments

Comments
 (0)