Skip to content

Commit 179b27b

Browse files
committed
Keep changes within FlatLinearValueConstraints
1 parent 209a78c commit 179b27b

File tree

6 files changed

+114
-117
lines changed

6 files changed

+114
-117
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -454,20 +454,6 @@ class IntegerRelation {
454454
addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor));
455455
}
456456

457-
/// Adds a new local variable as the mod of an affine function of other
458-
/// variables. The coefficients of the operands of the mod are provided in
459-
/// `lhs` and `rhs` respectively. Three constraints are added to provide a
460-
/// conservative bound for the mod:
461-
/// 1. rhs > 0 (assumption/precondition)
462-
/// 2. lhs % rhs < rhs
463-
/// 3. lhs % rhs >= 0
464-
/// We ensure the rhs is positive so we can assume the result is positive.
465-
void addLocalModConservativeBounds(ArrayRef<MPInt> lhs, ArrayRef<MPInt> rhs);
466-
void addLocalModConservativeBounds(ArrayRef<int64_t> lhs,
467-
ArrayRef<int64_t> rhs) {
468-
addLocalModConservativeBounds(getMPIntVec(lhs), getMPIntVec(rhs));
469-
}
470-
471457
/// Projects out (aka eliminates) `num` variables starting at position
472458
/// `pos`. The resulting constraint system is the shadow along the dimensions
473459
/// that still exist. This method may not always be integer exact.

mlir/include/mlir/IR/AffineExprVisitor.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,9 @@ class SimpleAffineExprFlattener
413413
/// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
414414
/// symbolic rhs expression. `localExpr` is the simplified tree expression
415415
/// (AffineExpr) corresponding to the quantifier.
416-
virtual void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
417-
ArrayRef<int64_t> rhs);
416+
virtual LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
417+
ArrayRef<int64_t> rhs,
418+
AffineExpr localExpr);
418419

419420
private:
420421
/// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
@@ -423,10 +424,11 @@ class SimpleAffineExprFlattener
423424
/// quantifier is already present, we put the coefficient in the proper index
424425
/// of `result`, otherwise we add a new local variable and put the coefficient
425426
/// there.
426-
void addLocalVariableSemiAffine(AffineExpr expr, ArrayRef<int64_t> lhs,
427-
ArrayRef<int64_t> rhs,
428-
SmallVectorImpl<int64_t> &result,
429-
unsigned long resultSize);
427+
LogicalResult addLocalVariableSemiAffine(AffineExpr expr,
428+
ArrayRef<int64_t> lhs,
429+
ArrayRef<int64_t> rhs,
430+
SmallVectorImpl<int64_t> &result,
431+
unsigned long resultSize);
430432

431433
// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
432434
// A floordiv is thus flattened by introducing a new local variable q, and

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,20 @@ using namespace presburger;
3636
namespace {
3737

3838
// See comments for SimpleAffineExprFlattener.
39-
// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
40-
// constraint information associated with mod's, floordiv's, and ceildiv's
41-
// in FlatLinearConstraints 'localVarCst'.
42-
struct AffineExprFlattener : public SimpleAffineExprFlattener {
43-
public:
39+
// An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by
40+
// recording constraint information associated with mod's, floordiv's, and
41+
// ceildiv's in FlatLinearConstraints 'localVarCst'.
42+
struct AffineExprFlattenerWithLocalVars : public SimpleAffineExprFlattener {
43+
using SimpleAffineExprFlattener::SimpleAffineExprFlattener;
44+
4445
// Constraints connecting newly introduced local variables (for mod's and
4546
// div's) to existing (dimensional and symbolic) ones. These are always
4647
// inequalities.
4748
IntegerPolyhedron localVarCst;
4849

49-
AffineExprFlattener(unsigned nDims, unsigned nSymbols,
50-
bool addConservativeSemiAffineBounds = false)
50+
AffineExprFlattenerWithLocalVars(unsigned nDims, unsigned nSymbols)
5151
: SimpleAffineExprFlattener(nDims, nSymbols),
52-
localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)),
53-
addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {}
54-
55-
bool hasUnhandledSemiAffineExpressions() const {
56-
return unhandledSemiAffineExpressions;
57-
}
52+
localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {};
5853

5954
private:
6055
// Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
@@ -70,30 +65,71 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
7065
localVarCst.addLocalFloorDiv(dividend, divisor);
7166
}
7267

73-
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
74-
// expr) when the rhs is a symbolic expression. The local identifier added
75-
// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
76-
// function of other identifiers, coefficients of which are specified in the
77-
// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
78-
// symbolic rhs expression. `localExpr` is the simplified tree expression
79-
// (AffineExpr) corresponding to the quantifier.
80-
void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
81-
ArrayRef<int64_t> rhs) override {
82-
SimpleAffineExprFlattener::addLocalIdSemiAffine(localExpr, lhs, rhs);
83-
if (!addConservativeSemiAffineBounds) {
84-
unhandledSemiAffineExpressions = true;
85-
return;
86-
}
68+
// Semi-affine expressions are not supported by all flatteners.
69+
LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
70+
ArrayRef<int64_t> rhs,
71+
AffineExpr localExpr) override = 0;
72+
};
73+
74+
// An AffineExprFlattener is an AffineExprFlattenerWithLocalVars that explicitly
75+
// disallows semi-affine expressions. Flattening will fail if a semi-affine
76+
// expression is encountered.
77+
struct AffineExprFlattener : public AffineExprFlattenerWithLocalVars {
78+
using AffineExprFlattenerWithLocalVars::AffineExprFlattenerWithLocalVars;
79+
80+
LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
81+
ArrayRef<int64_t> rhs,
82+
AffineExpr localExpr) override {
83+
// AffineExprFlattener does not support semi-affine expressions.
84+
return failure();
85+
}
86+
};
87+
88+
// A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds
89+
// conservative bounds for semi-affine expressions (given assumptions hold). If
90+
// the assumptions required to add the semi-affine bounds are found not to hold
91+
// the final constraints set will be empty/inconsistent. If the assumptions are
92+
// never contradicted the final bounds still only will be correct if the
93+
// assumptions hold.
94+
struct SemiAffineExprFlattener : public AffineExprFlattenerWithLocalVars {
95+
using AffineExprFlattenerWithLocalVars::AffineExprFlattenerWithLocalVars;
96+
97+
LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
98+
ArrayRef<int64_t> rhs,
99+
AffineExpr localExpr) override {
100+
auto result =
101+
SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr);
102+
assert(succeeded(result) &&
103+
"unexpected failure in SimpleAffineExprFlattener");
104+
(void)result;
105+
87106
if (localExpr.getKind() == AffineExprKind::Mod) {
88-
localVarCst.addLocalModConservativeBounds(lhs, rhs);
89-
return;
107+
localVarCst.appendVar(VarKind::Local);
108+
// Add a conservative bound for `mod` assuming the rhs is > 0.
109+
110+
// Note: If the rhs is later found to be < 0 the following two constraints
111+
// will contradict each other (and lead to the final constraints set
112+
// becoming empty). If the sign of the rhs is never specified the bound
113+
// will assume it is positive.
114+
115+
// Upper bound: rhs - (lhs % rhs) - 1 >= 0 i.e. lhs % rhs < rhs
116+
// This only holds if the rhs is > 0.
117+
SmallVector<int64_t, 8> resultUpperBound(rhs);
118+
resultUpperBound.insert(resultUpperBound.end() - 1, -1);
119+
--resultUpperBound.back();
120+
localVarCst.addInequality(resultUpperBound);
121+
122+
// Lower bound: lhs % rhs >= 0 (always holds)
123+
SmallVector<int64_t, 8> resultLowerBound(rhs.size());
124+
resultLowerBound.insert(resultLowerBound.end() - 1, 1);
125+
localVarCst.addInequality(resultLowerBound);
126+
127+
return success();
90128
}
129+
91130
// TODO: Support other semi-affine expressions.
92-
unhandledSemiAffineExpressions = true;
131+
return failure();
93132
}
94-
95-
bool addConservativeSemiAffineBounds = false;
96-
bool unhandledSemiAffineExpressions = false;
97133
};
98134

99135
} // namespace
@@ -114,27 +150,34 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
114150
return success();
115151
}
116152

117-
AffineExprFlattener flattener(numDims, numSymbols,
118-
addConservativeSemiAffineBounds);
119-
// Use the same flattener to simplify each expression successively. This way
120-
// local variables / expressions are shared.
121-
for (auto expr : exprs) {
122-
auto flattenResult = flattener.walkPostOrder(expr);
123-
if (failed(flattenResult))
124-
return failure();
125-
if (flattener.hasUnhandledSemiAffineExpressions())
126-
return failure();
127-
}
153+
auto flattenExprs =
154+
[&](AffineExprFlattenerWithLocalVars &flattener) -> LogicalResult {
155+
// Use the same flattener to simplify each expression successively. This way
156+
// local variables / expressions are shared.
157+
for (auto expr : exprs) {
158+
auto flattenResult = flattener.walkPostOrder(expr);
159+
if (failed(flattenResult))
160+
return failure();
161+
}
162+
163+
assert(flattener.operandExprStack.size() == exprs.size());
164+
flattenedExprs->clear();
165+
flattenedExprs->assign(flattener.operandExprStack.begin(),
166+
flattener.operandExprStack.end());
128167

129-
assert(flattener.operandExprStack.size() == exprs.size());
130-
flattenedExprs->clear();
131-
flattenedExprs->assign(flattener.operandExprStack.begin(),
132-
flattener.operandExprStack.end());
168+
if (localVarCst)
169+
localVarCst->clearAndCopyFrom(flattener.localVarCst);
170+
171+
return success();
172+
};
133173

134-
if (localVarCst)
135-
localVarCst->clearAndCopyFrom(flattener.localVarCst);
174+
if (addConservativeSemiAffineBounds) {
175+
SemiAffineExprFlattener flattener(numDims, numSymbols);
176+
return flattenExprs(flattener);
177+
}
136178

137-
return success();
179+
AffineExprFlattener flattener(numDims, numSymbols);
180+
return flattenExprs(flattener);
138181
}
139182

140183
// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,37 +1521,6 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<MPInt> dividend,
15211521
getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
15221522
}
15231523

1524-
/// Adds a new local variable as the mod of an affine function of other
1525-
/// variables. The coefficients of the operands of the mod are provided in `lhs`
1526-
/// and `rhs` respectively. Three constraints are added to provide a
1527-
/// conservative bound for the mod:
1528-
/// 1. rhs > 0 (assumption/precondition)
1529-
/// 2. lhs % rhs < rhs
1530-
/// 3. lhs % rhs >= 0
1531-
/// We ensure the rhs is positive so we can assume the result is positive.
1532-
void IntegerRelation::addLocalModConservativeBounds(ArrayRef<MPInt> lhs,
1533-
ArrayRef<MPInt> rhs) {
1534-
appendVar(VarKind::Local);
1535-
1536-
// Ensure the rhs is > 0 (most likely case).
1537-
// If this constraint does not hold the following bounds are incorrect.
1538-
SmallVector<MPInt, 8> rhsCopy(rhs);
1539-
rhsCopy.insert(rhsCopy.end() - 1, MPInt(0));
1540-
rhsCopy.back() -= MPInt(1);
1541-
addInequality(rhsCopy);
1542-
1543-
// rhs - (lhs % rhs) - 1 >= 0 i.e. lhs % rhs < rhs
1544-
SmallVector<MPInt, 8> resultUpperBound(rhs);
1545-
resultUpperBound.insert(resultUpperBound.end() - 1, MPInt(-1));
1546-
resultUpperBound.back() -= MPInt(1);
1547-
addInequality(resultUpperBound);
1548-
1549-
// lhs % rhs >= 0
1550-
SmallVector<MPInt, 8> resultLowerBound(rhs.size());
1551-
resultLowerBound.insert(resultLowerBound.end() - 1, MPInt(1));
1552-
addInequality(resultLowerBound);
1553-
}
1554-
15551524
/// Finds an equality that equates the specified variable to a constant.
15561525
/// Returns the position of the equality row. If 'symbolic' is set to true,
15571526
/// symbols are also treated like a constant, i.e., an affine function of the

mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
10-
1110
#include "mlir/Dialect/Vector/IR/VectorOps.h"
12-
#include "llvm/Support/Debug.h"
13-
1411
namespace mlir::vector {
1512

1613
FailureOr<ConstantOrScalableBound::BoundSize>

mlir/lib/IR/AffineExpr.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,8 +1248,7 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
12481248
localExprs, context);
12491249
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
12501250
localExprs, context);
1251-
addLocalVariableSemiAffine(a * b, mulLhs, rhs, lhs, lhs.size());
1252-
return success();
1251+
return addLocalVariableSemiAffine(a * b, mulLhs, rhs, lhs, lhs.size());
12531252
}
12541253

12551254
// Get the RHS constant.
@@ -1302,8 +1301,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13021301
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
13031302
localExprs, context);
13041303
AffineExpr modExpr = dividendExpr % divisorExpr;
1305-
addLocalVariableSemiAffine(modExpr, modLhs, rhs, lhs, lhs.size());
1306-
return success();
1304+
return addLocalVariableSemiAffine(modExpr, modLhs, rhs, lhs, lhs.size());
13071305
}
13081306

13091307
int64_t rhsConst = rhs[getConstantIndex()];
@@ -1387,19 +1385,22 @@ SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
13871385
return success();
13881386
}
13891387

1390-
void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1388+
LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
13911389
AffineExpr expr, ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
13921390
SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
13931391
assert(result.size() == resultSize &&
13941392
"`result` vector passed is not of correct size");
13951393
int loc;
1396-
if ((loc = findLocalId(expr)) == -1)
1397-
addLocalIdSemiAffine(expr, lhs, rhs);
1394+
if ((loc = findLocalId(expr)) == -1) {
1395+
if (failed(addLocalIdSemiAffine(lhs, rhs, expr)))
1396+
return failure();
1397+
}
13981398
std::fill(result.begin(), result.end(), 0);
13991399
if (loc == -1)
14001400
result[getLocalVarStartIndex() + numLocals - 1] = 1;
14011401
else
14021402
result[getLocalVarStartIndex() + loc] = 1;
1403+
return success();
14031404
}
14041405

14051406
// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
@@ -1434,8 +1435,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
14341435
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
14351436
localExprs, context);
14361437
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1437-
addLocalVariableSemiAffine(divExpr, divLhs, rhs, lhs, lhs.size());
1438-
return success();
1438+
return addLocalVariableSemiAffine(divExpr, divLhs, rhs, lhs, lhs.size());
14391439
}
14401440

14411441
// This is a pure affine expr; the RHS is a positive constant.
@@ -1506,14 +1506,14 @@ void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
15061506
// dividend and divisor are not used here; an override of this method uses it.
15071507
}
15081508

1509-
void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr,
1510-
ArrayRef<int64_t> lhs,
1511-
ArrayRef<int64_t> rhs) {
1509+
LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
1510+
ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
15121511
for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
15131512
subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
15141513
localExprs.push_back(localExpr);
15151514
++numLocals;
15161515
// lhs and rhs are not used here; an override of this method uses them.
1516+
return success();
15171517
}
15181518

15191519
int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {

0 commit comments

Comments
 (0)