Skip to content

mlir/Presburger: reinstate use of LogicalResult #97415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/DynamicAPInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <optional>

namespace mlir {
namespace presburger {
using llvm::DynamicAPInt;
using llvm::failure;
using llvm::int64fromDynamicAPInt;
using llvm::LogicalResult;
using llvm::SmallVectorImpl;
using llvm::success;

class IntegerRelation;
class IntegerPolyhedron;
Expand Down Expand Up @@ -478,7 +482,7 @@ class IntegerRelation {
/// equality detection; if successful, the constant is substituted for the
/// variable everywhere in the constraint system and then removed from the
/// system.
bool constantFoldVar(unsigned pos);
LogicalResult constantFoldVar(unsigned pos);

/// This method calls `constantFoldVar` for the specified range of variables,
/// `num` variables starting at position `pos`.
Expand All @@ -501,7 +505,7 @@ class IntegerRelation {
/// 3) this = {0 <= d0 <= 5, 1 <= d1 <= 9}
/// other = {2 <= d0 <= 6, 5 <= d1 <= 15},
/// output = {0 <= d0 <= 6, 1 <= d1 <= 15}
bool unionBoundingBox(const IntegerRelation &other);
LogicalResult unionBoundingBox(const IntegerRelation &other);

/// Returns the smallest known constant bound for the extent of the specified
/// variable (pos^th), i.e., the smallest known constant that is greater
Expand Down Expand Up @@ -774,8 +778,8 @@ class IntegerRelation {
/// Eliminates a single variable at `position` from equality and inequality
/// constraints. Returns `success` if the variable was eliminated, and
/// `failure` otherwise.
inline bool gaussianEliminateVar(unsigned position) {
return gaussianEliminateVars(position, position + 1) == 1;
inline LogicalResult gaussianEliminateVar(unsigned position) {
return success(gaussianEliminateVars(position, position + 1) == 1);
}

/// Removes local variables using equalities. Each equality is checked if it
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Analysis/Presburger/Simplex.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ class LexSimplexBase : public SimplexBase {
/// lexicopositivity of the basis transform. The row must have a non-positive
/// sample value. If this is not possible, return failure. This occurs when
/// the constraints have no solution or the sample value is zero.
bool moveRowUnknownToColumn(unsigned row);
LogicalResult moveRowUnknownToColumn(unsigned row);

/// Given a row that has a non-integer sample value, add an inequality to cut
/// away this fractional sample value from the polytope without removing any
Expand All @@ -459,7 +459,7 @@ class LexSimplexBase : public SimplexBase {
///
/// Return failure if the tableau became empty, and success if it didn't.
/// Failure status indicates that the polytope was integer empty.
bool addCut(unsigned row);
LogicalResult addCut(unsigned row);

/// Undo the addition of the last constraint. This is only called while
/// rolling back.
Expand Down Expand Up @@ -511,7 +511,7 @@ class LexSimplex : public LexSimplexBase {
MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;

/// Make the tableau configuration consistent.
bool restoreRationalConsistency();
LogicalResult restoreRationalConsistency();

/// Return whether the specified row is violated;
bool rowIsViolated(unsigned row) const;
Expand Down Expand Up @@ -626,7 +626,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// Return failure if the tableau became empty, indicating that the polytope
/// is always integer empty in the current symbol domain.
/// Return success otherwise.
bool doNonBranchingPivots();
LogicalResult doNonBranchingPivots();

/// Get a row that is always violated in the current domain, if one exists.
std::optional<unsigned> maybeGetAlwaysViolatedRow();
Expand All @@ -647,7 +647,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// at the time of the call. (This function may modify the symbol domain, but
/// failure statu indicates that the polytope was empty for all symbol values
/// in the initial domain.)
bool addSymbolicCut(unsigned row);
LogicalResult addSymbolicCut(unsigned row);

/// Get the numerator of the symbolic sample of the specific row.
/// This is an affine expression in the symbols with integer coefficients.
Expand Down Expand Up @@ -820,7 +820,7 @@ class Simplex : public SimplexBase {
///
/// Returns success if the unknown was successfully restored to a non-negative
/// sample value, failure otherwise.
bool restoreRow(Unknown &u);
LogicalResult restoreRow(Unknown &u);

/// Find a pivot to change the sample value of row in the specified
/// direction while preserving tableau consistency, except that if the
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,10 +1247,10 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
if (!areVarsAligned(*this, otherCst)) {
FlatLinearValueConstraints otherCopy(otherCst);
mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
return success(IntegerPolyhedron::unionBoundingBox(otherCopy));
return IntegerPolyhedron::unionBoundingBox(otherCopy);
}

return success(IntegerPolyhedron::unionBoundingBox(otherCst));
return IntegerPolyhedron::unionBoundingBox(otherCst);
}

//===----------------------------------------------------------------------===//
Expand Down
26 changes: 14 additions & 12 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -1552,22 +1553,22 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
return -1;
}

bool IntegerRelation::constantFoldVar(unsigned pos) {
LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
assert(pos < getNumVars() && "invalid position");
int rowIdx;
if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
return false;
return failure();

// atEq(rowIdx, pos) is either -1 or 1.
assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
setAndEliminate(pos, constVal);
return true;
return success();
}

void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
if (!constantFoldVar(t))
if (constantFoldVar(t).failed())
t++;
}
}
Expand Down Expand Up @@ -1944,9 +1945,9 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
if (atEq(r, pos) != 0) {
// Use Gaussian elimination here (since we have an equality).
bool ret = gaussianEliminateVar(pos);
LogicalResult ret = gaussianEliminateVar(pos);
(void)ret;
assert(ret && "Gaussian elimination guaranteed to succeed");
assert(ret.succeeded() && "Gaussian elimination guaranteed to succeed");
LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
LLVM_DEBUG(dump());
return;
Expand Down Expand Up @@ -2173,7 +2174,8 @@ static void getCommonConstraints(const IntegerRelation &a,

// Computes the bounding box with respect to 'other' by finding the min of the
// lower bounds and the max of the upper bounds along each of the dimensions.
bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
LogicalResult
IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
assert(getNumLocalVars() == 0 && "local ids not supported yet here");

Expand Down Expand Up @@ -2201,13 +2203,13 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
if (!extent.has_value())
// TODO: symbolic extents when necessary.
// TODO: handle union if a dimension is unbounded.
return false;
return failure();

auto otherExtent = otherCst.getConstantBoundOnDimSize(
d, &otherLb, &otherLbFloorDivisor, &otherUb);
if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
// TODO: symbolic extents when necessary.
return false;
return failure();

assert(lbFloorDivisor > 0 && "divisor always expected to be positive");

Expand All @@ -2227,7 +2229,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
auto constLb = getConstantBound(BoundType::LB, d);
auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
if (!constLb.has_value() || !constOtherLb.has_value())
return false;
return failure();
std::fill(minLb.begin(), minLb.end(), 0);
minLb.back() = std::min(*constLb, *constOtherLb);
}
Expand All @@ -2243,7 +2245,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
auto constUb = getConstantBound(BoundType::UB, d);
auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
if (!constUb.has_value() || !constOtherUb.has_value())
return false;
return failure();
std::fill(maxUb.begin(), maxUb.end(), 0);
maxUb.back() = std::max(*constUb, *constOtherUb);
}
Expand Down Expand Up @@ -2281,7 +2283,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
// union (since the above are just the union along dimensions); we shouldn't
// be discarding any other constraints on the symbols.

return true;
return success();
}

bool IntegerRelation::isColZero(unsigned pos) const {
Expand Down
59 changes: 31 additions & 28 deletions mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <functional>
Expand Down Expand Up @@ -753,18 +754,18 @@ class presburger::SetCoalescer {
/// \___\|/ \_____/
///
///
bool coalescePairCutCase(unsigned i, unsigned j);
LogicalResult coalescePairCutCase(unsigned i, unsigned j);

/// Types the inequality `ineq` according to its `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
bool typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);
LogicalResult typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);

/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
/// -`eq` >= 0 according to their `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
bool typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);
LogicalResult typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);

/// Replaces the element at position `i` with the last element and erases
/// the last element for both `disjuncts` and `simplices`.
Expand All @@ -775,7 +776,7 @@ class presburger::SetCoalescer {
/// successfully coalesced. The simplices in `simplices` need to be the ones
/// constructed from `disjuncts`. At this point, there are no empty
/// disjuncts in `disjuncts` left.
bool coalescePair(unsigned i, unsigned j);
LogicalResult coalescePair(unsigned i, unsigned j);
};

/// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty
Expand Down Expand Up @@ -818,7 +819,7 @@ PresburgerRelation SetCoalescer::coalesce() {
cuttingIneqsB.clear();
if (i == j)
continue;
if (coalescePair(i, j)) {
if (coalescePair(i, j).succeeded()) {
broken = true;
break;
}
Expand Down Expand Up @@ -902,15 +903,15 @@ void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
/// \___\|/ \_____/
///
///
bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
/// All inequalities of `b` need to be redundant. We already know that the
/// redundant ones are, so only the cutting ones remain to be checked.
Simplex &simp = simplices[i];
IntegerRelation &disjunct = disjuncts[i];
if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<DynamicAPInt> curr) {
return !isFacetContained(curr, simp);
}))
return false;
return failure();
IntegerRelation newSet(disjunct.getSpace());

for (ArrayRef<DynamicAPInt> curr : redundantIneqsA)
Expand All @@ -920,23 +921,25 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
newSet.addInequality(curr);

addCoalescedDisjunct(i, j, newSet);
return true;
return success();
}

bool SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp) {
LogicalResult SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq,
Simplex &simp) {
Simplex::IneqType type = simp.findIneqType(ineq);
if (type == Simplex::IneqType::Redundant)
redundantIneqsB.push_back(ineq);
else if (type == Simplex::IneqType::Cut)
cuttingIneqsB.push_back(ineq);
else
return false;
return true;
return failure();
return success();
}

bool SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp) {
if (!typeInequality(eq, simp))
return false;
LogicalResult SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq,
Simplex &simp) {
if (typeInequality(eq, simp).failed())
return failure();
negEqs.push_back(getNegatedCoeffs(eq));
ArrayRef<DynamicAPInt> inv(negEqs.back());
return typeInequality(inv, simp);
Expand All @@ -951,15 +954,15 @@ void SetCoalescer::eraseDisjunct(unsigned i) {
simplices.pop_back();
}

bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) {

IntegerRelation &a = disjuncts[i];
IntegerRelation &b = disjuncts[j];
/// Handling of local ids is not yet implemented, so these cases are
/// skipped.
/// TODO: implement local id support.
if (a.getNumLocalVars() != 0 || b.getNumLocalVars() != 0)
return false;
return failure();
Simplex &simpA = simplices[i];
Simplex &simpB = simplices[j];

Expand All @@ -969,34 +972,34 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
// inequality is encountered during typing, the two IntegerRelations
// cannot be coalesced.
for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
if (!typeInequality(a.getInequality(k), simpB))
return false;
if (typeInequality(a.getInequality(k), simpB).failed())
return failure();

for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
if (!typeEquality(a.getEquality(k), simpB))
return false;
if (typeEquality(a.getEquality(k), simpB).failed())
return failure();

std::swap(redundantIneqsA, redundantIneqsB);
std::swap(cuttingIneqsA, cuttingIneqsB);

for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
if (!typeInequality(b.getInequality(k), simpA))
return false;
if (typeInequality(b.getInequality(k), simpA).failed())
return failure();

for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
if (!typeEquality(b.getEquality(k), simpA))
return false;
if (typeEquality(b.getEquality(k), simpA).failed())
return failure();

// If there are no cutting inequalities of `a`, `b` is contained
// within `a`.
if (cuttingIneqsA.empty()) {
eraseDisjunct(j);
return true;
return success();
}

// Try to apply the cut case
if (coalescePairCutCase(i, j))
return true;
if (coalescePairCutCase(i, j).succeeded())
return success();

// Swap the vectors to compare the pair (j,i) instead of (i,j).
std::swap(redundantIneqsA, redundantIneqsB);
Expand All @@ -1006,7 +1009,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
// within `a`.
if (cuttingIneqsA.empty()) {
eraseDisjunct(i);
return true;
return success();
}

// Try to apply the cut case
Expand Down
Loading
Loading