12
12
#include " mlir/Interfaces/ValueBoundsOpInterface.h"
13
13
14
14
using namespace mlir ;
15
+ using presburger::BoundType;
15
16
16
17
namespace mlir {
17
18
namespace scf {
@@ -20,28 +21,7 @@ namespace {
20
21
struct ForOpInterface
21
22
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22
23
23
- // / Populate bounds of values/dimensions for iter_args/OpResults. If the
24
- // / value/dimension size does not change in an iteration, we can deduce that
25
- // / it the same as the initial value/dimension.
26
- // /
27
- // / Example 1:
28
- // / %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29
- // / ...
30
- // / %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31
- // / scf.yield %1 : tensor<?xf32>
32
- // / }
33
- // / --> bound(%0)[0] == bound(%t)[0]
34
- // / --> bound(%arg0)[0] == bound(%t)[0]
35
- // /
36
- // / Example 2:
37
- // / %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38
- // / %sz = tensor.dim %arg0 : tensor<?xf32>
39
- // / %incr = arith.addi %sz, %c1 : index
40
- // / %1 = tensor.empty(%incr) : tensor<?xf32>
41
- // / scf.yield %1 : tensor<?xf32>
42
- // / }
43
- // / --> The yielded tensor dimension size changes with each iteration. Such
44
- // / loops are not supported and no constraints are added.
24
+ // / Populate bounds of values/dimensions for iter_args/OpResults.
45
25
static void populateIterArgBounds (scf::ForOp forOp, Value value,
46
26
std::optional<int64_t > dim,
47
27
ValueBoundsConstraintSet &cstr) {
@@ -53,31 +33,59 @@ struct ForOpInterface
53
33
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber ();
54
34
}
55
35
36
+ // An EQ constraint can be added if the yielded value (dimension size)
37
+ // equals the corresponding block argument (dimension size).
56
38
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ())
57
39
.getOperand (iterArgIdx);
58
40
Value iterArg = forOp.getRegionIterArg (iterArgIdx);
59
41
Value initArg = forOp.getInitArgs ()[iterArgIdx];
60
42
61
- // Populate constraints for the yielded value.
62
- cstr.populateConstraints (yieldedValue, dim);
63
- // Populate constraints for the iter_arg. This is just to ensure that the
64
- // iter_arg is mapped in the constraint set, which is a prerequisite for
65
- // `compare`. It may lead to a recursive call to this function in case the
66
- // iter_arg was not visited when the constraints for the yielded value were
67
- // populated, but no additional work is done.
68
- cstr.populateConstraints (iterArg, dim);
69
-
70
- // An EQ constraint can be added if the yielded value (dimension size)
71
- // equals the corresponding block argument (dimension size).
72
- if (cstr.compare (yieldedValue, dim,
73
- ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
74
- dim)) {
43
+ auto addEqBound = [&]() {
75
44
if (dim.has_value ()) {
76
45
cstr.bound (value)[*dim] == cstr.getExpr (initArg, dim);
77
46
} else {
78
47
cstr.bound (value) == initArg;
79
48
}
49
+ };
50
+
51
+ if (yieldedValue == iterArg) {
52
+ addEqBound ();
53
+ return ;
54
+ }
55
+
56
+ // Compute EQ bound for yielded value.
57
+ AffineMap bound;
58
+ ValueDimList boundOperands;
59
+ LogicalResult status = ValueBoundsConstraintSet::computeBound (
60
+ bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61
+ [&](Value v, std::optional<int64_t > d, ValueBoundsConstraintSet &cstr) {
62
+ // Stop when reaching a block argument of the loop body.
63
+ if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
64
+ return bbArg.getOwner ()->getParentOp () == forOp;
65
+ // Stop when reaching a value that is defined outside of the loop. It
66
+ // is impossible to reach an iter_arg from there.
67
+ Operation *op = v.getDefiningOp ();
68
+ return forOp.getRegion ().findAncestorOpInRegion (*op) == nullptr ;
69
+ });
70
+ if (failed (status))
71
+ return ;
72
+ if (bound.getNumResults () != 1 )
73
+ return ;
74
+
75
+ // Check if computed bound equals the corresponding iter_arg.
76
+ Value singleValue = nullptr ;
77
+ std::optional<int64_t > singleDim;
78
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult (0 ))) {
79
+ int64_t idx = dimExpr.getPosition ();
80
+ singleValue = boundOperands[idx].first ;
81
+ singleDim = boundOperands[idx].second ;
82
+ } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult (0 ))) {
83
+ int64_t idx = symExpr.getPosition () + bound.getNumDims ();
84
+ singleValue = boundOperands[idx].first ;
85
+ singleDim = boundOperands[idx].second ;
80
86
}
87
+ if (singleValue == iterArg && singleDim == dim)
88
+ addEqBound ();
81
89
}
82
90
83
91
void populateBoundsForIndexValue (Operation *op, Value value,
0 commit comments