Skip to content

Commit 62841f2

Browse files
committed
[mlir][affine][Analysis] Add conservative bounds for semi-affine mods
This path adds support for computing bounds for semi-affine mod expression to FlatLinearConstraints. This is then enabled within the ScalableValueBoundsConstraintSet to allow computing the bounds of scalable remainder loops. E.g. computing the bound of something like: ``` %0 = affine.apply #remainder_start_index()[%c8_vscale] scf.for %i = %0 to %c1000 step %c8_vscale { %remaining_iterations = affine.apply #remaining_iterations(%i) // The upper bound for the remainder loop iterations should be: // %c8_vscale - 1 (expressed as an affine map, // affine_map<()[s0] -> (s0 * 8 - 1)>, where s0 is vscale) %bound = "test.reify_bound"(%remaining_iterations) <{scalable, ...}> } ``` There are caveats to this implementation. To be able to add a bound for a `mod` we need to assume the rhs is positive (> 0). This may not be known when adding the bounds for the `mod` expression. So to handle this a constraint is added for `rhs > 0`, this may later be found not to hold (in which case the constraints set becomes empty/invalid). This is not a problem for computing scalable bounds where it's safe to assume `s0` is vscale (or some positive multiple of it). But this may need to be considered when enabling this feature elsewhere (to ensure correctness).
1 parent 74ed79f commit 62841f2

File tree

11 files changed

+257
-59
lines changed

11 files changed

+257
-59
lines changed

mlir/include/mlir/Analysis/FlatLinearValueConstraints.h

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
6666
/// Return the kind of this object.
6767
Kind getKind() const override { return Kind::FlatLinearConstraints; }
6868

69+
/// Flag to control if conservative semi-affine bounds should be added in
70+
/// `addBound()`.
71+
enum class AddConservativeSemiAffineBounds { No = 0, Yes };
72+
6973
/// Adds a bound for the variable at the specified position with constraints
7074
/// being drawn from the specified bound map. In case of an EQ bound, the
7175
/// bound map is expected to have exactly one result. In case of a LB/UB, the
@@ -77,21 +81,39 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
7781
/// as a closed bound by +1/-1 respectively. In case of an EQ bound, it can
7882
/// only be added as a closed bound.
7983
///
84+
/// Conservative bounds for semi-affine expressions will be added if
85+
/// `AddConservativeSemiAffineBounds` is set to `Yes`. This currently does not
86+
/// cover all semi-affine expressions, so `addBound()` still may fail with
87+
/// this set. Note: If enabled it is possible for the resulting constraint set
88+
/// to become empty if a precondition of a conservative bound is found not to
89+
/// hold.
90+
///
8091
/// Note: The dimensions/symbols of this FlatLinearConstraints must match the
8192
/// dimensions/symbols of the affine map.
82-
LogicalResult addBound(presburger::BoundType type, unsigned pos,
83-
AffineMap boundMap, bool isClosedBound);
93+
LogicalResult addBound(
94+
presburger::BoundType type, unsigned pos, AffineMap boundMap,
95+
bool isClosedBound,
96+
AddConservativeSemiAffineBounds = AddConservativeSemiAffineBounds::No);
8497

8598
/// Adds a bound for the variable at the specified position with constraints
8699
/// being drawn from the specified bound map. In case of an EQ bound, the
87100
/// bound map is expected to have exactly one result. In case of a LB/UB, the
88101
/// bound map may have more than one result, for each of which an inequality
89102
/// is added.
103+
///
104+
/// Conservative bounds for semi-affine expressions will be added if
105+
/// `AddConservativeSemiAffineBounds` is set to `Yes`. This currently does not
106+
/// cover all semi-affine expressions, so `addBound()` still may fail with
107+
/// this set. If enabled it is possible for the resulting constraint set
108+
/// to become empty if a precondition of a conservative bound is found not to
109+
/// hold.
110+
///
90111
/// Note: The dimensions/symbols of this FlatLinearConstraints must match the
91112
/// dimensions/symbols of the affine map. By default the lower bound is closed
92113
/// and the upper bound is open.
93-
LogicalResult addBound(presburger::BoundType type, unsigned pos,
94-
AffineMap boundMap);
114+
LogicalResult addBound(
115+
presburger::BoundType type, unsigned pos, AffineMap boundMap,
116+
AddConservativeSemiAffineBounds = AddConservativeSemiAffineBounds::No);
95117

96118
/// The `addBound` overload above hides the inherited overloads by default, so
97119
/// we explicitly introduce them here.
@@ -193,7 +215,8 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
193215
/// Note: This is a shared helper function of `addLowerOrUpperBound` and
194216
/// `composeMatchingMap`.
195217
LogicalResult flattenAlignedMapAndMergeLocals(
196-
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
218+
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
219+
bool addConservativeSemiAffineBounds = false);
197220

198221
/// Prints the number of constraints, dimensions, symbols and locals in the
199222
/// FlatLinearConstraints. Also, prints for each variable whether there is
@@ -468,18 +491,19 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
468491
/// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
469492
/// dimensions, symbols, and additional variables that represent floor divisions
470493
/// of dimensions, symbols, and in turn other floor divisions. Returns failure
471-
/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled).
494+
/// if 'expr' could not be flattened (i.e., an unhandled semi-affine was found).
472495
/// 'cst' contains constraints that connect newly introduced local variables
473496
/// to existing dimensional and symbolic variables. See documentation for
474497
/// AffineExprFlattener on how mod's and div's are flattened.
475-
LogicalResult getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
476-
unsigned numSymbols,
477-
SmallVectorImpl<int64_t> *flattenedExpr,
478-
FlatLinearConstraints *cst = nullptr);
498+
LogicalResult
499+
getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
500+
SmallVectorImpl<int64_t> *flattenedExpr,
501+
FlatLinearConstraints *cst = nullptr,
502+
bool addConservativeSemiAffineBounds = false);
479503

480504
/// Flattens the result expressions of the map to their corresponding flattened
481505
/// forms and set in 'flattenedExprs'. Returns failure if any expression in the
482-
/// map could not be flattened (i.e., semi-affine is not yet handled). 'cst'
506+
/// map could not be flattened (i.e., an unhandled semi-affine was found). 'cst'
483507
/// contains constraints that connect newly introduced local variables to
484508
/// existing dimensional and / symbolic variables. See documentation for
485509
/// AffineExprFlattener on how mod's and div's are flattened. For all affine
@@ -490,7 +514,8 @@ LogicalResult getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
490514
LogicalResult
491515
getFlattenedAffineExprs(AffineMap map,
492516
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
493-
FlatLinearConstraints *cst = nullptr);
517+
FlatLinearConstraints *cst = nullptr,
518+
bool addConservativeSemiAffineBounds = false);
494519
LogicalResult
495520
getFlattenedAffineExprs(IntegerSet set,
496521
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,20 @@ class IntegerRelation {
454454
addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor));
455455
}
456456

457+
/// Adds a new local variable as the mod of an affine function of other
458+
/// variables. The coefficients of the operands of the mod are provided in
459+
/// `lhs` and `rhs` respectively. Three constraints are added to provide a
460+
/// conservative bound for the mod:
461+
/// 1. rhs > 0 (assumption/precondition)
462+
/// 2. lhs % rhs < rhs
463+
/// 3. lhs % rhs >= 0
464+
/// We ensure the rhs is positive so we can assume the result is positive.
465+
void addLocalModConservativeBounds(ArrayRef<MPInt> lhs, ArrayRef<MPInt> rhs);
466+
void addLocalModConservativeBounds(ArrayRef<int64_t> lhs,
467+
ArrayRef<int64_t> rhs) {
468+
addLocalModConservativeBounds(getMPIntVec(lhs), getMPIntVec(rhs));
469+
}
470+
457471
/// Projects out (aka eliminates) `num` variables starting at position
458472
/// `pos`. The resulting constraint system is the shadow along the dimensions
459473
/// that still exist. This method may not always be integer exact.

mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ struct ScalableValueBoundsConstraintSet
3333
MLIRContext *context,
3434
ValueBoundsConstraintSet::StopConditionFn stopCondition,
3535
unsigned vscaleMin, unsigned vscaleMax)
36-
: RTTIExtends(context, stopCondition), vscaleMin(vscaleMin),
37-
vscaleMax(vscaleMax) {};
36+
: RTTIExtends(context, stopCondition,
37+
/*addConservativeSemiAffineBounds=*/true),
38+
vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
3839

3940
using RTTIExtends::bound;
4041
using RTTIExtends::StopConditionFn;

mlir/include/mlir/IR/AffineExprVisitor.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,8 @@ class SimpleAffineExprFlattener
413413
/// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
414414
/// symbolic rhs expression. `localExpr` is the simplified tree expression
415415
/// (AffineExpr) corresponding to the quantifier.
416-
virtual void addLocalIdSemiAffine(AffineExpr localExpr);
416+
virtual void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
417+
ArrayRef<int64_t> rhs);
417418

418419
private:
419420
/// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
@@ -422,7 +423,8 @@ class SimpleAffineExprFlattener
422423
/// quantifier is already present, we put the coefficient in the proper index
423424
/// of `result`, otherwise we add a new local variable and put the coefficient
424425
/// there.
425-
void addLocalVariableSemiAffine(AffineExpr expr,
426+
void addLocalVariableSemiAffine(AffineExpr expr, ArrayRef<int64_t> lhs,
427+
ArrayRef<int64_t> rhs,
426428
SmallVectorImpl<int64_t> &result,
427429
unsigned long resultSize);
428430

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ class ValueBoundsConstraintSet
313313
/// An index-typed value or the dimension of a shaped-type value.
314314
using ValueDim = std::pair<Value, int64_t>;
315315

316-
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
316+
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition,
317+
bool addConservativeSemiAffineBounds = false);
317318

318319
/// Return "true" if, based on the current state of the constraint system,
319320
/// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
@@ -404,6 +405,9 @@ class ValueBoundsConstraintSet
404405

405406
/// The current stop condition function.
406407
StopConditionFn stopCondition = nullptr;
408+
409+
/// Should conservative bounds be added for semi-affine expressions.
410+
bool addConservativeSemiAffineBounds = false;
407411
};
408412

409413
} // namespace mlir

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,15 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
4646
// inequalities.
4747
IntegerPolyhedron localVarCst;
4848

49-
AffineExprFlattener(unsigned nDims, unsigned nSymbols)
49+
AffineExprFlattener(unsigned nDims, unsigned nSymbols,
50+
bool addConservativeSemiAffineBounds = false)
5051
: SimpleAffineExprFlattener(nDims, nSymbols),
51-
localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
52+
localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)),
53+
addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {}
54+
55+
bool hasUnhandledSemiAffineExpressions() const {
56+
return unhandledSemiAffineExpressions;
57+
}
5258

5359
private:
5460
// Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
@@ -63,35 +69,61 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
6369
// Update localVarCst.
6470
localVarCst.addLocalFloorDiv(dividend, divisor);
6571
}
72+
73+
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
74+
// expr) when the rhs is a symbolic expression. The local identifier added
75+
// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
76+
// function of other identifiers, coefficients of which are specified in the
77+
// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
78+
// symbolic rhs expression. `localExpr` is the simplified tree expression
79+
// (AffineExpr) corresponding to the quantifier.
80+
void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
81+
ArrayRef<int64_t> rhs) override {
82+
SimpleAffineExprFlattener::addLocalIdSemiAffine(localExpr, lhs, rhs);
83+
if (!addConservativeSemiAffineBounds) {
84+
unhandledSemiAffineExpressions = true;
85+
return;
86+
}
87+
if (localExpr.getKind() == AffineExprKind::Mod) {
88+
localVarCst.addLocalModConservativeBounds(lhs, rhs);
89+
return;
90+
}
91+
// TODO: Support other semi-affine expressions.
92+
unhandledSemiAffineExpressions = true;
93+
}
94+
95+
bool addConservativeSemiAffineBounds = false;
96+
bool unhandledSemiAffineExpressions = false;
6697
};
6798

6899
} // namespace
69100

70101
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
71102
// flattened. For example two specific cases:
72-
// 1. semi-affine expressions not handled yet.
103+
// 1. an unhandled semi-affine expressions is found.
73104
// 2. has poison expression (i.e., division by zero).
74105
static LogicalResult
75106
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
76107
unsigned numSymbols,
77108
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
78-
FlatLinearConstraints *localVarCst) {
109+
FlatLinearConstraints *localVarCst,
110+
bool addConservativeSemiAffineBounds = false) {
79111
if (exprs.empty()) {
80112
if (localVarCst)
81113
*localVarCst = FlatLinearConstraints(numDims, numSymbols);
82114
return success();
83115
}
84116

85-
AffineExprFlattener flattener(numDims, numSymbols);
117+
AffineExprFlattener flattener(numDims, numSymbols,
118+
addConservativeSemiAffineBounds);
86119
// Use the same flattener to simplify each expression successively. This way
87120
// local variables / expressions are shared.
88121
for (auto expr : exprs) {
89-
if (!expr.isPureAffine())
90-
return failure();
91-
// has poison expression
92122
auto flattenResult = flattener.walkPostOrder(expr);
93123
if (failed(flattenResult))
94124
return failure();
125+
if (flattener.hasUnhandledSemiAffineExpressions())
126+
return failure();
95127
}
96128

97129
assert(flattener.operandExprStack.size() == exprs.size());
@@ -106,33 +138,33 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
106138
}
107139

108140
// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
109-
// be flattened (semi-affine expressions not handled yet).
110-
LogicalResult
111-
mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
112-
unsigned numSymbols,
113-
SmallVectorImpl<int64_t> *flattenedExpr,
114-
FlatLinearConstraints *localVarCst) {
141+
// be flattened (an unhandled semi-affine was found).
142+
LogicalResult mlir::getFlattenedAffineExpr(
143+
AffineExpr expr, unsigned numDims, unsigned numSymbols,
144+
SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst,
145+
bool addConservativeSemiAffineBounds) {
115146
std::vector<SmallVector<int64_t, 8>> flattenedExprs;
116-
LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
117-
&flattenedExprs, localVarCst);
147+
LogicalResult ret =
148+
::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs,
149+
localVarCst, addConservativeSemiAffineBounds);
118150
*flattenedExpr = flattenedExprs[0];
119151
return ret;
120152
}
121153

122154
/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
123-
/// flattened (i.e., semi-affine expressions not handled yet).
155+
/// flattened (i.e., an unhandled semi-affine was found).
124156
LogicalResult mlir::getFlattenedAffineExprs(
125157
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
126-
FlatLinearConstraints *localVarCst) {
158+
FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) {
127159
if (map.getNumResults() == 0) {
128160
if (localVarCst)
129161
*localVarCst =
130162
FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
131163
return success();
132164
}
133-
return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
134-
map.getNumSymbols(), flattenedExprs,
135-
localVarCst);
165+
return ::getFlattenedAffineExprs(
166+
map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs,
167+
localVarCst, addConservativeSemiAffineBounds);
136168
}
137169

138170
LogicalResult mlir::getFlattenedAffineExprs(
@@ -641,9 +673,11 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
641673
}
642674

643675
LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
644-
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
676+
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
677+
bool addConservativeSemiAffineBounds) {
645678
FlatLinearConstraints localCst;
646-
if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
679+
if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst,
680+
addConservativeSemiAffineBounds))) {
647681
LLVM_DEBUG(llvm::dbgs()
648682
<< "composition unimplemented for semi-affine maps\n");
649683
return failure();
@@ -664,9 +698,9 @@ LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
664698
return success();
665699
}
666700

667-
LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
668-
AffineMap boundMap,
669-
bool isClosedBound) {
701+
LogicalResult FlatLinearConstraints::addBound(
702+
BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound,
703+
AddConservativeSemiAffineBounds addSemiAffineBounds) {
670704
assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
671705
assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
672706
assert(pos < getNumDimAndSymbolVars() && "invalid position");
@@ -680,7 +714,9 @@ LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
680714
bool lower = type == BoundType::LB || type == BoundType::EQ;
681715

682716
std::vector<SmallVector<int64_t, 8>> flatExprs;
683-
if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
717+
if (failed(flattenAlignedMapAndMergeLocals(
718+
boundMap, &flatExprs,
719+
addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
684720
return failure();
685721
assert(flatExprs.size() == boundMap.getNumResults());
686722

@@ -716,9 +752,11 @@ LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
716752
return success();
717753
}
718754

719-
LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
720-
AffineMap boundMap) {
721-
return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB);
755+
LogicalResult FlatLinearConstraints::addBound(
756+
BoundType type, unsigned pos, AffineMap boundMap,
757+
AddConservativeSemiAffineBounds addSemiAffineBounds) {
758+
return addBound(type, pos, boundMap,
759+
/*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds);
722760
}
723761

724762
/// Compute an explicit representation for local vars. For all systems coming
@@ -1243,7 +1281,8 @@ mlir::getMultiAffineFunctionFromMap(AffineMap map,
12431281
"AffineMap cannot produce divs without local representation");
12441282

12451283
// TODO: We shouldn't have to do this conversion.
1246-
Matrix<MPInt> mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
1284+
Matrix<MPInt> mat(map.getNumResults(),
1285+
map.getNumInputs() + divs.getNumDivs() + 1);
12471286
for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
12481287
for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
12491288
mat(i, j) = flattenedExprs[i][j];

0 commit comments

Comments
 (0)