Skip to content

[MLIR] NFC. Improve API signature + clang-tidy warning in IntegerRelation #128993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -738,11 +738,11 @@ class IntegerRelation {
/// Same as findSymbolicIntegerLexMin but produces lexmax instead of lexmin
SymbolicLexOpt findSymbolicIntegerLexMax() const;

/// Searches for a constraint with a non-zero coefficient at `colIdx` in
/// equality (isEq=true) or inequality (isEq=false) constraints.
/// Returns true and sets row found in search in `rowIdx`, false otherwise.
bool findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
unsigned *rowIdx) const;
/// Finds a constraint with a non-zero coefficient at `colIdx` in equality
/// (isEq=true) or inequality (isEq=false) constraints. Returns the position
/// of the row if it was found or none otherwise.
std::optional<unsigned> findConstraintWithNonZeroAt(unsigned colIdx,
bool isEq) const;

/// Return the set difference of this set and the given set, i.e.,
/// return `this \ set`.
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,
}

// Detect a variable as an expression of other variables.
unsigned idx;
if (!cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
std::optional<unsigned> idx;
if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true))) {
continue;
}

Expand All @@ -646,7 +646,7 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
if (j == pos)
continue;
int64_t c = cst.atEq64(idx, j);
int64_t c = cst.atEq64(*idx, j);
if (c == 0)
continue;
// If any of the involved IDs hasn't been found yet, we can't proceed.
Expand All @@ -660,8 +660,8 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,
continue;

// Add constant term to AffineExpr.
expr = expr + cst.atEq64(idx, cst.getNumVars());
int64_t vPos = cst.atEq64(idx, pos);
expr = expr + cst.atEq64(*idx, cst.getNumVars());
int64_t vPos = cst.atEq64(*idx, pos);
assert(vPos != 0 && "expected non-zero here");
if (vPos > 0)
expr = (-expr).floorDiv(vPos);
Expand Down
68 changes: 31 additions & 37 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,22 +564,18 @@ void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
*this = other;
}

// Searches for a constraint with a non-zero coefficient at `colIdx` in
// equality (isEq=true) or inequality (isEq=false) constraints.
// Returns true and sets row found in search in `rowIdx`, false otherwise.
bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
unsigned *rowIdx) const {
std::optional<unsigned>
IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq) const {
assert(colIdx < getNumCols() && "position out of bounds");
auto at = [&](unsigned rowIdx) -> DynamicAPInt {
return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
};
unsigned e = isEq ? getNumEqualities() : getNumInequalities();
for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
if (at(*rowIdx) != 0) {
return true;
}
for (unsigned rowIdx = 0; rowIdx < e; ++rowIdx) {
if (at(rowIdx) != 0)
return rowIdx;
}
return false;
return std::nullopt;
}

void IntegerRelation::normalizeConstraintsByGCD() {
Expand Down Expand Up @@ -1088,31 +1084,30 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
unsigned pivotCol = 0;
for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
// Find a row which has a non-zero coefficient in column 'j'.
unsigned pivotRow;
if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) {
// No pivot row in equalities with non-zero at 'pivotCol'.
if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) {
// If inequalities are also non-zero in 'pivotCol', it can be
// eliminated.
continue;
}
break;
std::optional<unsigned> pivotRow =
findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true);
// No pivot row in equalities with non-zero at 'pivotCol'.
if (!pivotRow) {
// If inequalities are also non-zero in 'pivotCol', it can be eliminated.
if ((pivotRow = findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false)))
break;
continue;
}

// Eliminate variable at 'pivotCol' from each equality row.
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
eliminateFromConstraint(this, i, *pivotRow, pivotCol, posStart,
/*isEq=*/true);
equalities.normalizeRow(i);
}

// Eliminate variable at 'pivotCol' from each inequality row.
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
eliminateFromConstraint(this, i, *pivotRow, pivotCol, posStart,
/*isEq=*/false);
inequalities.normalizeRow(i);
}
removeEquality(pivotRow);
removeEquality(*pivotRow);
gcdTightenInequalities();
}
// Update position limit based on number eliminated.
Expand All @@ -1125,31 +1120,31 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
bool IntegerRelation::gaussianEliminate() {
gcdTightenInequalities();
unsigned firstVar = 0, vars = getNumVars();
unsigned nowDone, eqs, pivotRow;
unsigned nowDone, eqs;
std::optional<unsigned> pivotRow;
for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) {
// Finds the first non-empty column.
for (; firstVar < vars; ++firstVar) {
if (!findConstraintWithNonZeroAt(firstVar, true, &pivotRow))
continue;
break;
if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true)))
break;
}
// The matrix has been normalized to row echelon form.
if (firstVar >= vars)
break;

// The first pivot row found is below where it should currently be placed.
if (pivotRow > nowDone) {
equalities.swapRows(pivotRow, nowDone);
pivotRow = nowDone;
if (*pivotRow > nowDone) {
equalities.swapRows(*pivotRow, nowDone);
*pivotRow = nowDone;
}

// Normalize all lower equations and all inequalities.
for (unsigned i = nowDone + 1; i < eqs; ++i) {
eliminateFromConstraint(this, i, pivotRow, firstVar, 0, true);
eliminateFromConstraint(this, i, *pivotRow, firstVar, 0, true);
equalities.normalizeRow(i);
}
for (unsigned i = 0, ineqs = getNumInequalities(); i < ineqs; ++i) {
eliminateFromConstraint(this, i, pivotRow, firstVar, 0, false);
eliminateFromConstraint(this, i, *pivotRow, firstVar, 0, false);
inequalities.normalizeRow(i);
}
gcdTightenInequalities();
Expand Down Expand Up @@ -2290,9 +2285,8 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
}

bool IntegerRelation::isColZero(unsigned pos) const {
unsigned rowPos;
return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) &&
!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos);
return !findConstraintWithNonZeroAt(pos, /*isEq=*/false) &&
!findConstraintWithNonZeroAt(pos, /*isEq=*/true);
}

/// Find positions of inequalities and equalities that do not have a coefficient
Expand Down Expand Up @@ -2600,16 +2594,16 @@ void IntegerRelation::print(raw_ostream &os) const {
for (unsigned j = 0, f = getNumCols(); j < f; ++j)
updatePrintMetrics<DynamicAPInt>(atIneq(i, j), ptm);
// Print using PrintMetrics.
unsigned MIN_SPACING = 1;
constexpr unsigned kMinSpacing = 1;
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
printWithPrintMetrics<DynamicAPInt>(os, atEq(i, j), MIN_SPACING, ptm);
printWithPrintMetrics<DynamicAPInt>(os, atEq(i, j), kMinSpacing, ptm);
}
os << " = 0\n";
}
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
printWithPrintMetrics<DynamicAPInt>(os, atIneq(i, j), MIN_SPACING, ptm);
printWithPrintMetrics<DynamicAPInt>(os, atIneq(i, j), kMinSpacing, ptm);
}
os << " >= 0\n";
}
Expand Down