Skip to content

Commit e5e1bc0

Browse files
committed
Revert "[mlir][SCF][NFC] ValueBoundsConstraintSet: Simplify scf.for implementation (#86239)"
This reverts commit 24e4429. gcc7 bot is broken
1 parent f2d8218 commit e5e1bc0

File tree

1 file changed

+44
-36
lines changed

1 file changed

+44
-36
lines changed

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1313

1414
using namespace mlir;
15+
using presburger::BoundType;
1516

1617
namespace mlir {
1718
namespace scf {
@@ -20,28 +21,7 @@ namespace {
2021
struct ForOpInterface
2122
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
2223

23-
/// Populate bounds of values/dimensions for iter_args/OpResults. If the
24-
/// value/dimension size does not change in an iteration, we can deduce that
25-
/// it the same as the initial value/dimension.
26-
///
27-
/// Example 1:
28-
/// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29-
/// ...
30-
/// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31-
/// scf.yield %1 : tensor<?xf32>
32-
/// }
33-
/// --> bound(%0)[0] == bound(%t)[0]
34-
/// --> bound(%arg0)[0] == bound(%t)[0]
35-
///
36-
/// Example 2:
37-
/// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38-
/// %sz = tensor.dim %arg0 : tensor<?xf32>
39-
/// %incr = arith.addi %sz, %c1 : index
40-
/// %1 = tensor.empty(%incr) : tensor<?xf32>
41-
/// scf.yield %1 : tensor<?xf32>
42-
/// }
43-
/// --> The yielded tensor dimension size changes with each iteration. Such
44-
/// loops are not supported and no constraints are added.
24+
/// Populate bounds of values/dimensions for iter_args/OpResults.
4525
static void populateIterArgBounds(scf::ForOp forOp, Value value,
4626
std::optional<int64_t> dim,
4727
ValueBoundsConstraintSet &cstr) {
@@ -53,31 +33,59 @@ struct ForOpInterface
5333
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
5434
}
5535

36+
// An EQ constraint can be added if the yielded value (dimension size)
37+
// equals the corresponding block argument (dimension size).
5638
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
5739
.getOperand(iterArgIdx);
5840
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
5941
Value initArg = forOp.getInitArgs()[iterArgIdx];
6042

61-
// Populate constraints for the yielded value.
62-
cstr.populateConstraints(yieldedValue, dim);
63-
// Populate constraints for the iter_arg. This is just to ensure that the
64-
// iter_arg is mapped in the constraint set, which is a prerequisite for
65-
// `compare`. It may lead to a recursive call to this function in case the
66-
// iter_arg was not visited when the constraints for the yielded value were
67-
// populated, but no additional work is done.
68-
cstr.populateConstraints(iterArg, dim);
69-
70-
// An EQ constraint can be added if the yielded value (dimension size)
71-
// equals the corresponding block argument (dimension size).
72-
if (cstr.compare(yieldedValue, dim,
73-
ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
74-
dim)) {
43+
auto addEqBound = [&]() {
7544
if (dim.has_value()) {
7645
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
7746
} else {
7847
cstr.bound(value) == initArg;
7948
}
49+
};
50+
51+
if (yieldedValue == iterArg) {
52+
addEqBound();
53+
return;
54+
}
55+
56+
// Compute EQ bound for yielded value.
57+
AffineMap bound;
58+
ValueDimList boundOperands;
59+
LogicalResult status = ValueBoundsConstraintSet::computeBound(
60+
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
62+
// Stop when reaching a block argument of the loop body.
63+
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
64+
return bbArg.getOwner()->getParentOp() == forOp;
65+
// Stop when reaching a value that is defined outside of the loop. It
66+
// is impossible to reach an iter_arg from there.
67+
Operation *op = v.getDefiningOp();
68+
return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
69+
});
70+
if (failed(status))
71+
return;
72+
if (bound.getNumResults() != 1)
73+
return;
74+
75+
// Check if computed bound equals the corresponding iter_arg.
76+
Value singleValue = nullptr;
77+
std::optional<int64_t> singleDim;
78+
if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
79+
int64_t idx = dimExpr.getPosition();
80+
singleValue = boundOperands[idx].first;
81+
singleDim = boundOperands[idx].second;
82+
} else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
83+
int64_t idx = symExpr.getPosition() + bound.getNumDims();
84+
singleValue = boundOperands[idx].first;
85+
singleDim = boundOperands[idx].second;
8086
}
87+
if (singleValue == iterArg && singleDim == dim)
88+
addEqBound();
8189
}
8290

8391
void populateBoundsForIndexValue(Operation *op, Value value,

0 commit comments

Comments
 (0)