@@ -20,6 +20,16 @@ namespace {
20
20
struct ForOpInterface
21
21
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22
22
23
+ static AffineExpr getTripCountExpr (scf::ForOp forOp,
24
+ ValueBoundsConstraintSet &cstr) {
25
+ AffineExpr lbExpr = cstr.getExpr (forOp.getLowerBound ());
26
+ AffineExpr ubExpr = cstr.getExpr (forOp.getUpperBound ());
27
+ AffineExpr stepExpr = cstr.getExpr (forOp.getStep ());
28
+ AffineExpr tripCountExpr =
29
+ AffineExpr (ubExpr - lbExpr).ceilDiv (stepExpr); // (ub - lb) / step
30
+ return tripCountExpr;
31
+ }
32
+
23
33
// / Populate bounds of values/dimensions for iter_args/OpResults. If the
24
34
// / value/dimension size does not change in an iteration, we can deduce that
25
35
// / it the same as the initial value/dimension.
@@ -77,11 +87,7 @@ struct ForOpInterface
77
87
// `value` is result of `forOp`, we can prove that:
78
88
// %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
79
89
// Where trip_count is (ub - lb) / step.
80
- AffineExpr lbExpr = cstr.getExpr (forOp.getLowerBound ());
81
- AffineExpr ubExpr = cstr.getExpr (forOp.getUpperBound ());
82
- AffineExpr stepExpr = cstr.getExpr (forOp.getStep ());
83
- AffineExpr tripCountExpr =
84
- AffineExpr (ubExpr - lbExpr).ceilDiv (stepExpr); // (ub - lb) / step
90
+ AffineExpr tripCountExpr = getTripCountExpr (forOp, cstr);
85
91
AffineExpr oneIterAdvanceExpr =
86
92
cstr.getExpr (yieldedValue) - cstr.getExpr (iterArg);
87
93
cstr.bound (value) ==
@@ -93,9 +99,18 @@ struct ForOpInterface
93
99
auto forOp = cast<ForOp>(op);
94
100
95
101
if (value == forOp.getInductionVar ()) {
96
- // TODO: Take into account step size.
97
102
cstr.bound (value) >= forOp.getLowerBound ();
98
103
cstr.bound (value) < forOp.getUpperBound ();
104
+ // iv <= lb + ((ub-lb)/step - 1) * step
105
+ // This bound does not replace the `iv < ub` constraint mentioned above,
106
+ // since constraints involving the multiplication of two two constraint
107
+ // set dimensions are not supported.
108
+ AffineExpr tripCountMinusOne =
109
+ getTripCountExpr (forOp, cstr) - cstr.getExpr (1 );
110
+ AffineExpr computedUpperBound =
111
+ cstr.getExpr (forOp.getLowerBound ()) +
112
+ AffineExpr (tripCountMinusOne * cstr.getExpr (forOp.getStep ()));
113
+ cstr.bound (value) <= computedUpperBound;
99
114
return ;
100
115
}
101
116
0 commit comments