Skip to content

[MLIR][Affine] Improve memref region bounding size and shape computation #129009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,31 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
/// we explicitly introduce them here.
using IntegerPolyhedron::addBound;

/// Returns a non-negative constant bound on the extent (upper bound - lower
/// bound) of the specified variable if it is found to be a constant; returns
/// std::nullopt if it's not a constant. This method treats symbolic
/// variables specially, i.e., it looks for constant differences between
/// affine expressions involving only the symbolic variables. 'lb', if
/// provided, is set to the lower bound map associated with the constant
/// difference, and similarly, `ub` to the upper bound. Note that 'lb', 'ub'
/// are purely symbolic and will correspond to the symbolic variables of the
/// constaint set.
// Egs: 0 <= i <= 15, return 16.
// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
// ceil(s0 - 7 / 8) = floor(s0 / 8)).
/// The difference between this method and
/// IntegerRelation::getConstantBoundOnDimSize is that unlike the latter, this
/// makes use of affine expressions and maps in its inference and provides
/// output with affine maps; it thus handles local variables by detecting them
/// as affine functions of the symbols when possible.
std::optional<int64_t>
getConstantBoundOnDimSize(MLIRContext *context, unsigned pos,
AffineMap *lb = nullptr, AffineMap *ub = nullptr,
unsigned *minLbPos = nullptr,
unsigned *minUbPos = nullptr) const;

/// Returns the constraint system as an integer set. Returns a null integer
/// set if the system has no constraints, or if an integer set couldn't be
/// constructed as a result of a local variable's explicit representation not
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ class IntegerRelation {
/// intersection with no simplification of any sort attempted.
void append(const IntegerRelation &other);

/// Finds an equality that equates the specified variable to a constant.
/// Returns the position of the equality row. If 'symbolic' is set to true,
/// symbols are also treated like a constant, i.e., an affine function of the
/// symbols is also treated like a constant. Returns -1 if such an equality
/// could not be found.
int findEqualityToConstant(unsigned pos, bool symbolic = false) const;

/// Return the intersection of the two relations.
/// If there are locals, they will be merged.
IntegerRelation intersect(IntegerRelation other) const;
Expand Down
30 changes: 12 additions & 18 deletions mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ struct MemRefRegion {
/// to slice operands (which correspond to symbols).
/// If 'addMemRefDimBounds' is true, constant upper/lower bounds
/// [0, memref.getDimSize(i)) are added for each MemRef dimension 'i'.
/// If `dropLocalVars` is true, all local variables in `cst` are projected
/// out.
///
/// For example, the memref region for this operation at loopDepth = 1 will
/// be:
Expand All @@ -513,9 +515,14 @@ struct MemRefRegion {
/// {memref = %A, write = false, {%i <= m0 <= %i + 7} }
/// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
///
/// If `dropOuterIVs` is true, project out any IVs other than those among
/// `loopDepth` surrounding IVs, which would be symbols. If `dropOuterIVs`
/// is false, the IVs would be turned into local variables instead of being
/// projected out.
LogicalResult compute(Operation *op, unsigned loopDepth,
const ComputationSliceState *sliceState = nullptr,
bool addMemRefDimBounds = true);
bool addMemRefDimBounds = true,
bool dropLocalVars = true, bool dropOuterIVs = true);

FlatAffineValueConstraints *getConstraints() { return &cst; }
const FlatAffineValueConstraints *getConstraints() const { return &cst; }
Expand All @@ -530,31 +537,18 @@ struct MemRefRegion {
/// corresponding dimension-wise bounds major to minor. The number of elements
/// and all the dimension-wise bounds are guaranteed to be non-negative. We
/// use int64_t instead of uint64_t since index types can be at most
/// int64_t. `lbs` are set to the lower bounds for each of the rank
/// dimensions, and lbDivisors contains the corresponding denominators for
/// floorDivs.
/// int64_t. `lbs` are set to the lower bound maps for each of the rank
/// dimensions where each of these maps is purely symbolic in the constraints
/// set's symbols.
std::optional<int64_t> getConstantBoundingSizeAndShape(
SmallVectorImpl<int64_t> *shape = nullptr,
std::vector<SmallVector<int64_t, 4>> *lbs = nullptr,
SmallVectorImpl<int64_t> *lbDivisors = nullptr) const;
SmallVectorImpl<AffineMap> *lbs = nullptr) const;

/// Gets the lower and upper bound map for the dimensional variable at
/// `pos`.
void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
AffineMap &ubMap) const;

/// A wrapper around FlatAffineValueConstraints::getConstantBoundOnDimSize().
/// 'pos' corresponds to the position of the memref shape's dimension (major
/// to minor) which matches 1:1 with the dimensional variable positions in
/// 'cst'.
std::optional<int64_t>
getConstantBoundOnDimSize(unsigned pos,
SmallVectorImpl<int64_t> *lb = nullptr,
int64_t *lbFloorDivisor = nullptr) const {
assert(pos < getRank() && "invalid position");
return cst.getConstantBoundOnDimSize64(pos, lb);
}

/// Returns the size of this MemRefRegion in bytes.
std::optional<int64_t> getRegionSize();

Expand Down
258 changes: 226 additions & 32 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,49 @@ std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
return {lbMap, ubMap};
}

/// Express the pos^th identifier of `cst` as an affine expression in
/// terms of other identifiers, if they are available in `exprs`, using the
/// equality at position `idx` in `cs`t. Populates `exprs` with such an
/// expression if possible, and return true. Returns false otherwise.
static bool detectAsExpr(const FlatLinearConstraints &cst, unsigned pos,
unsigned idx, MLIRContext *context,
SmallVectorImpl<AffineExpr> &exprs) {
// Initialize with a `0` expression.
auto expr = getAffineConstantExpr(0, context);

// Traverse `idx`th equality and construct the possible affine expression in
// terms of known identifiers.
unsigned j, e;
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
if (j == pos)
continue;
int64_t c = cst.atEq64(idx, j);
if (c == 0)
continue;
// If any of the involved IDs hasn't been found yet, we can't proceed.
if (!exprs[j])
break;
expr = expr + exprs[j] * c;
}
if (j < e)
// Can't construct expression as it depends on a yet uncomputed
// identifier.
return false;

// Add constant term to AffineExpr.
expr = expr + cst.atEq64(idx, cst.getNumVars());
int64_t vPos = cst.atEq64(idx, pos);
assert(vPos != 0 && "expected non-zero here");
if (vPos > 0)
expr = (-expr).floorDiv(vPos);
else
// vPos < 0.
expr = expr.floorDiv(-vPos);
// Successfully constructed expression.
exprs[pos] = expr;
return true;
}

/// Compute a representation of `num` identifiers starting at `offset` in `cst`
/// as affine expressions involving other known identifiers. Each identifier's
/// expression (in terms of known identifiers) is populated into `memo`.
Expand Down Expand Up @@ -636,41 +679,13 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,

// Detect a variable as an expression of other variables.
std::optional<unsigned> idx;
if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true))) {
if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true)))
continue;
}

// Build AffineExpr solving for variable 'pos' in terms of all others.
auto expr = getAffineConstantExpr(0, context);
unsigned j, e;
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
if (j == pos)
continue;
int64_t c = cst.atEq64(*idx, j);
if (c == 0)
continue;
// If any of the involved IDs hasn't been found yet, we can't proceed.
if (!memo[j])
break;
expr = expr + memo[j] * c;
}
if (j < e)
// Can't construct expression as it depends on a yet uncomputed
// variable.
if (detectAsExpr(cst, pos, *idx, context, memo)) {
changed = true;
continue;

// Add constant term to AffineExpr.
expr = expr + cst.atEq64(*idx, cst.getNumVars());
int64_t vPos = cst.atEq64(*idx, pos);
assert(vPos != 0 && "expected non-zero here");
if (vPos > 0)
expr = (-expr).floorDiv(vPos);
else
// vPos < 0.
expr = expr.floorDiv(-vPos);
// Successfully constructed expression.
memo[pos] = expr;
changed = true;
}
}
// This loop is guaranteed to reach a fixed point - since once an
// variable's explicit form is computed (in memo[pos]), it's not updated
Expand Down Expand Up @@ -891,6 +906,185 @@ FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
}

/// Given an equality or inequality (`isEquality` used to disambiguate) of `cst`
/// at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the
/// `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If
/// the equality/inequality contains any unknown id, return None. Otherwise
/// return sum as `AffineExpr`.
static std::optional<AffineExpr> getAsExpr(const FlatLinearConstraints &cst,
unsigned pos, MLIRContext *context,
ArrayRef<AffineExpr> exprs,
unsigned idx, bool isEquality) {
// Initialize with a `0` expression.
auto expr = getAffineConstantExpr(0, context);

SmallVector<int64_t, 8> row =
isEquality ? cst.getEquality64(idx) : cst.getInequality64(idx);

// Traverse `idx`th equality and construct the possible affine expression in
// terms of known identifiers.
unsigned j, e;
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
if (j == pos)
continue;
int64_t c = row[j];
if (c == 0)
continue;
// If any of the involved IDs hasn't been found yet, we can't proceed.
if (!exprs[j])
break;
expr = expr + exprs[j] * c;
}
if (j < e)
// Can't construct expression as it depends on a yet uncomputed
// identifier.
return std::nullopt;

// Add constant term to AffineExpr.
expr = expr + row[cst.getNumVars()];
return expr;
}

std::optional<int64_t> FlatLinearConstraints::getConstantBoundOnDimSize(
MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub,
unsigned *minLbPos, unsigned *minUbPos) const {

assert(pos < getNumDimVars() && "Invalid identifier position");

auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t> cst,
ArrayRef<AffineExpr> whiteListCols) {
for (int i = getNumDimAndSymbolVars(), e = cst.size() - 1; i < e; ++i) {
if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant())
continue;
if (cst[i] != 0)
return false;
}
return true;
};

// Detect the necesary local variables first.
SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
(void)computeLocalVars(memo, context);

// Find an equality for 'pos'^th identifier that equates it to some function
// of the symbolic identifiers (+ constant).
int eqPos = findEqualityToConstant(pos, /*symbolic=*/true);
// If the equality involves a local var that can not be expressed as a
// symbolic or constant affine expression, we bail out.
if (eqPos != -1 && freeOfUnknownLocalVars(getEquality64(eqPos), memo)) {
// This identifier can only take a single value.
if (lb && detectAsExpr(*this, pos, eqPos, context, memo)) {
AffineExpr equalityExpr =
simplifyAffineExpr(memo[pos], 0, getNumSymbolVars());
*lb = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), equalityExpr);
if (ub)
*ub = *lb;
}
if (minLbPos)
*minLbPos = eqPos;
if (minUbPos)
*minUbPos = eqPos;
return 1;
}

// Positions of constraints that are lower/upper bounds on the variable.
SmallVector<unsigned, 4> lbIndices, ubIndices;

// Note inequalities that give lower and upper bounds.
getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
/*eqIndices=*/nullptr, /*offset=*/0,
/*num=*/getNumDimVars());

std::optional<int64_t> minDiff = std::nullopt;
unsigned minLbPosition = 0, minUbPosition = 0;
AffineExpr minLbExpr, minUbExpr;

// Traverse each lower bound and upper bound pair, to compute the difference
// between them.
for (unsigned ubPos : ubIndices) {
// Construct sum of all ids other than `pos`th in the given upper bound row.
std::optional<AffineExpr> maybeUbExpr =
getAsExpr(*this, pos, context, memo, ubPos, /*isEquality=*/false);
if (!maybeUbExpr.has_value() || !(*maybeUbExpr).isSymbolicOrConstant())
continue;

// Canonical form of an inequality that constrains the upper bound on
// an id `x_i` is of the form:
// `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1.
// Therefore the upper bound on `x_i` will be
// `(
// sum(c_j*x_j) where j != i
// +
// c_0
// )
// /
// -(c_i)`. Divison here is a floorDiv.
AffineExpr ubExpr = maybeUbExpr->floorDiv(-atIneq64(ubPos, pos));
assert(-atIneq64(ubPos, pos) > 0 && "invalid upper bound index");

// Go over each lower bound.
for (unsigned lbPos : lbIndices) {
// Construct sum of all ids other than `pos`th in the given lower bound
// row.
std::optional<AffineExpr> maybeLbExpr =
getAsExpr(*this, pos, context, memo, lbPos, /*isEquality=*/false);
if (!maybeLbExpr.has_value() || !(*maybeLbExpr).isSymbolicOrConstant())
continue;

// Canonical form of an inequality that is constraining the lower bound
// on an id `x_i is of the form:
// `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1.
// Therefore upperBound on `x_i` will be
// `-(
// sum(c_j*x_j) where j != i
// +
// c_0
// )
// /
// c_i`. Divison here is a ceilDiv.
int64_t divisor = atIneq64(lbPos, pos);
// We convert the `ceilDiv` for floordiv with the formula:
// `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`,
// since uniformly keeping divisons as `floorDiv` helps their
// simplification.
AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1).floorDiv(divisor);
assert(atIneq64(lbPos, pos) > 0 && "invalid lower bound index");

AffineExpr difference =
simplifyAffineExpr(ubExpr - lbExpr + 1, 0, getNumSymbolVars());
// If the difference is not constant, ignore the lower bound - upper bound
// pair.
auto constantDiff = dyn_cast<AffineConstantExpr>(difference);
if (!constantDiff)
continue;

int64_t diffValue = constantDiff.getValue();
// This bound is non-negative by definition.
diffValue = std::max<int64_t>(diffValue, 0);
if (!minDiff || diffValue < *minDiff) {
minDiff = diffValue;
minLbPosition = lbPos;
minUbPosition = ubPos;
minLbExpr = lbExpr;
minUbExpr = ubExpr;
}
}
}

// Populate outputs where available and needed.
if (lb && minDiff) {
*lb = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), minLbExpr);
}
if (ub)
*ub = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), minUbExpr);
if (minLbPos)
*minLbPos = minLbPosition;
if (minUbPos)
*minUbPos = minUbPosition;

return minDiff;
}

IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
if (getNumConstraints() == 0)
// Return universal set (always true): 0 == 0.
Expand Down
Loading