Skip to content

Commit 5630143

Browse files
committed
[MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns
In LexSimplex, instead of adding equalities as a pair of inequalities, add them as a single row, move them into the basis, and keep them there. There will always be a valid basis involving all non-redundant equalities. Such equalities will then be ignored in some other operations, such as when looking for pivot columns. This speeds them up a little bit. More importantly, this is an important precursor patch to adding support for symbolic integer lexmin, as this heuristic can sometimes make a big difference there. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D122165
1 parent 08543a5 commit 5630143

File tree

3 files changed

+76
-18
lines changed

3 files changed

+76
-18
lines changed

mlir/include/mlir/Analysis/Presburger/Simplex.h

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,21 +166,21 @@ class SimplexBase {
166166
/// false otherwise.
167167
bool isEmpty() const;
168168

169-
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
170-
/// is the current number of variables, then the corresponding inequality is
171-
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
172-
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
173-
174169
/// Returns the number of variables in the tableau.
175170
unsigned getNumVariables() const;
176171

177172
/// Returns the number of constraints in the tableau.
178173
unsigned getNumConstraints() const;
179174

175+
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
176+
/// is the current number of variables, then the corresponding inequality is
177+
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
178+
virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
179+
180180
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
181181
/// is the current number of variables, then the corresponding equality is
182182
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
183-
void addEquality(ArrayRef<int64_t> coeffs);
183+
virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
184184

185185
/// Add new variables to the end of the list of variables.
186186
void appendVariable(unsigned count = 1);
@@ -249,6 +249,14 @@ class SimplexBase {
249249
/// coefficient for it.
250250
Optional<unsigned> findAnyPivotRow(unsigned col);
251251

252+
/// Return any column that this row can be pivoted with, ignoring tableau
253+
/// consistency. Equality rows are not considered.
254+
///
255+
/// Returns an empty optional if no pivot is possible, which happens only when
256+
/// the column unknown is a variable and no constraint has a non-zero
257+
/// coefficient for it.
258+
Optional<unsigned> findAnyPivotCol(unsigned row);
259+
252260
/// Swap the row with the column in the tableau's data structures but not the
253261
/// tableau itself. This is used by pivot.
254262
void swapRowWithCol(unsigned row, unsigned col);
@@ -295,6 +303,7 @@ class SimplexBase {
295303
RemoveLastVariable,
296304
UnmarkEmpty,
297305
UnmarkLastRedundant,
306+
UnmarkLastEquality,
298307
RestoreBasis
299308
};
300309

@@ -308,13 +317,14 @@ class SimplexBase {
308317
/// Undo the operation represented by the log entry.
309318
void undo(UndoLogEntry entry);
310319

311-
/// Return the number of fixed columns, as described in the constructor above,
312-
/// this is the number of columns beyond those for the variables in var.
313-
unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
320+
unsigned getNumFixedCols() const { return numFixedCols; }
314321

315322
/// Stores whether or not a big M column is present in the tableau.
316323
bool usingBigM;
317324

325+
/// denom + const + maybe M + equality columns
326+
unsigned numFixedCols;
327+
318328
/// The number of rows in the tableau.
319329
unsigned nRow;
320330

@@ -435,9 +445,12 @@ class LexSimplex : public SimplexBase {
435445
///
436446
/// This just adds the inequality to the tableau and does not try to create a
437447
/// consistent tableau configuration.
438-
void addInequality(ArrayRef<int64_t> coeffs) final {
439-
addRow(coeffs, /*makeRestricted=*/true);
440-
}
448+
void addInequality(ArrayRef<int64_t> coeffs) final;
449+
450+
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
451+
/// is the current number of variables, then the corresponding equality is
452+
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
453+
void addEquality(ArrayRef<int64_t> coeffs) final;
441454

442455
/// Get a snapshot of the current state. This is used for rolling back.
443456
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
@@ -533,6 +546,11 @@ class Simplex : public SimplexBase {
533546
/// state and marks the Simplex empty if this is not possible.
534547
void addInequality(ArrayRef<int64_t> coeffs) final;
535548

549+
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
550+
/// is the current number of variables, then the corresponding equality is
551+
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
552+
void addEquality(ArrayRef<int64_t> coeffs) final;
553+
536554
/// Compute the maximum or minimum value of the given row, depending on
537555
/// direction. The specified row is never pivoted. On return, the row may
538556
/// have a negative sample value if the direction is down.

mlir/lib/Analysis/Presburger/Simplex.cpp

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ using Direction = Simplex::Direction;
1919
const int nullIndex = std::numeric_limits<int>::max();
2020

2121
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
22-
: usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
23-
nRedundant(0), tableau(0, nCol), empty(false) {
24-
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
22+
: usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
23+
nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
24+
colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
2525
for (unsigned i = 0; i < nVar; ++i) {
2626
var.emplace_back(Orientation::Column, /*restricted=*/false,
27-
/*pos=*/getNumFixedCols() + i);
27+
/*pos=*/numFixedCols + i);
2828
colUnknown.push_back(i);
2929
}
3030
}
@@ -309,7 +309,7 @@ void LexSimplex::restoreRationalConsistency() {
309309
// minimizes the change in sample value.
310310
LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
311311
Optional<unsigned> maybeColumn;
312-
for (unsigned col = 3; col < nCol; ++col) {
312+
for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
313313
if (tableau(row, col) <= 0)
314314
continue;
315315
maybeColumn =
@@ -648,7 +648,7 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
648648
///
649649
/// We simply add two opposing inequalities, which force the expression to
650650
/// be zero.
651-
void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
651+
void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
652652
addInequality(coeffs);
653653
SmallVector<int64_t, 8> negatedCoeffs;
654654
for (int64_t coeff : coeffs)
@@ -705,6 +705,15 @@ Optional<unsigned> SimplexBase::findAnyPivotRow(unsigned col) {
705705
return {};
706706
}
707707

708+
// This doesn't find a pivot column only if the row has zero coefficients for
709+
// every column not marked as an equality.
710+
Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
711+
for (unsigned col = getNumFixedCols(); col < nCol; ++col)
712+
if (tableau(row, col) != 0)
713+
return col;
714+
return {};
715+
}
716+
708717
// It's not valid to remove the constraint by deleting the column since this
709718
// would result in an invalid basis.
710719
void Simplex::undoLastConstraint() {
@@ -780,6 +789,10 @@ void SimplexBase::undo(UndoLogEntry entry) {
780789
empty = false;
781790
} else if (entry == UndoLogEntry::UnmarkLastRedundant) {
782791
nRedundant--;
792+
} else if (entry == UndoLogEntry::UnmarkLastEquality) {
793+
numFixedCols--;
794+
assert(getNumFixedCols() >= 2 + usingBigM &&
795+
"The denominator, constant, big M and symbols are always fixed!");
783796
} else if (entry == UndoLogEntry::RestoreBasis) {
784797
assert(!savedBases.empty() && "No bases saved!");
785798

@@ -1110,6 +1123,26 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
11101123
return sample;
11111124
}
11121125

1126+
void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
1127+
addRow(coeffs, /*makeRestricted=*/true);
1128+
}
1129+
1130+
/// Try to make the equality a fixed column by finding any pivot and performing
1131+
/// it. The only time this is not possible is when the given equality's
1132+
/// direction is already in the span of the existing fixed column equalities. In
1133+
/// that case, we just leave it in row position.
1134+
void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
1135+
const Unknown &u = con[addRow(coeffs, /*makeRestricted=*/true)];
1136+
Optional<unsigned> pivotCol = findAnyPivotCol(u.pos);
1137+
if (!pivotCol)
1138+
return;
1139+
1140+
pivot(u.pos, *pivotCol);
1141+
swapColumns(*pivotCol, getNumFixedCols());
1142+
numFixedCols++;
1143+
undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
1144+
}
1145+
11131146
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
11141147
if (empty)
11151148
return OptimumKind::Empty;

mlir/unittests/Analysis/Presburger/SimplexTest.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,10 @@ TEST(SimplexTest, addDivisionVariable) {
548548
ASSERT_TRUE(sample.hasValue());
549549
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
550550
}
551+
552+
TEST(LexSimplexTest, addEquality) {
553+
IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
554+
rel.addEquality({1, 0});
555+
LexSimplex simplex(rel);
556+
EXPECT_EQ(simplex.getNumConstraints(), 1u);
557+
}

0 commit comments

Comments
 (0)