Skip to content

Commit f08fe1f

Browse files
[MLIR][Presburger] Implement matrix inverse (#67382)
Shift the `determinant()` function from LinearTransform to Matrix. Implement a FracMatrix class, inheriting from Matrix<Fraction>, for inverses. Implement inverse for FracMatrix and intInverse for IntMatrix. Make Matrix internals protected instead of private so that Int/FracMatrix can access them.
1 parent 080fb3e commit f08fe1f

File tree

5 files changed

+241
-23
lines changed

5 files changed

+241
-23
lines changed

mlir/include/mlir/Analysis/Presburger/LinearTransform.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ class LinearTransform {
5050
return matrix.postMultiplyWithColumn(colVec);
5151
}
5252

53-
// Compute the determinant of the transform by converting it to row echelon
54-
// form and then taking the product of the diagonal.
55-
MPInt determinant();
56-
5753
private:
5854
IntMatrix matrix;
5955
};

mlir/include/mlir/Analysis/Presburger/Matrix.h

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ class Matrix {
189189
/// invariants satisfied.
190190
bool hasConsistentState() const;
191191

192-
private:
192+
protected:
193193
/// The current number of rows, columns, and reserved columns. The underlying
194194
/// data vector is viewed as an nRows x nReservedColumns matrix, of which the
195195
/// first nColumns columns are currently in use, and the remaining are
@@ -210,13 +210,7 @@ class IntMatrix : public Matrix<MPInt> {
210210
unsigned reservedColumns = 0)
211211
: Matrix<MPInt>(rows, columns, reservedRows, reservedColumns){};
212212

213-
IntMatrix(Matrix<MPInt> m)
214-
: Matrix<MPInt>(m.getNumRows(), m.getNumColumns(), m.getNumReservedRows(),
215-
m.getNumReservedColumns()) {
216-
for (unsigned i = 0; i < m.getNumRows(); i++)
217-
for (unsigned j = 0; j < m.getNumColumns(); j++)
218-
at(i, j) = m(i, j);
219-
};
213+
IntMatrix(Matrix<MPInt> m) : Matrix<MPInt>(std::move(m)){};
220214

221215
/// Return the identity matrix of the specified dimension.
222216
static IntMatrix identity(unsigned dimension);
@@ -239,6 +233,38 @@ class IntMatrix : public Matrix<MPInt> {
239233
/// Divide the columns of the specified row by their GCD.
240234
/// Returns the GCD of the columns of the specified row.
241235
MPInt normalizeRow(unsigned row);
236+
237+
// Compute the determinant of the matrix (cubic time).
238+
// Stores the integer inverse of the matrix in the pointer
239+
// passed (if any). The pointer is unchanged if the inverse
240+
// does not exist, which happens iff det = 0.
241+
// For a matrix M, the integer inverse is the matrix M' such that
242+
// M x M' = M'  M = det(M) x I.
243+
// Assert-fails if the matrix is not square.
244+
MPInt determinant(IntMatrix *inverse = nullptr) const;
245+
};
246+
247+
// An inherited class for rational matrices, with no new data attributes.
248+
// This class is for functionality that only applies to matrices of fractions.
249+
class FracMatrix : public Matrix<Fraction> {
250+
public:
251+
FracMatrix(unsigned rows, unsigned columns, unsigned reservedRows = 0,
252+
unsigned reservedColumns = 0)
253+
: Matrix<Fraction>(rows, columns, reservedRows, reservedColumns){};
254+
255+
FracMatrix(Matrix<Fraction> m) : Matrix<Fraction>(std::move(m)){};
256+
257+
explicit FracMatrix(IntMatrix m);
258+
259+
/// Return the identity matrix of the specified dimension.
260+
static FracMatrix identity(unsigned dimension);
261+
262+
// Compute the determinant of the matrix (cubic time).
263+
// Stores the inverse of the matrix in the pointer
264+
// passed (if any). The pointer is unchanged if the inverse
265+
// does not exist, which happens iff det = 0.
266+
// Assert-fails if the matrix is not square.
267+
Fraction determinant(FracMatrix *inverse = nullptr) const;
242268
};
243269

244270
} // namespace presburger

mlir/lib/Analysis/Presburger/Matrix.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,4 +432,120 @@ MPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) {
432432

433433
MPInt IntMatrix::normalizeRow(unsigned row) {
434434
return normalizeRow(row, getNumColumns());
435+
}
436+
437+
MPInt IntMatrix::determinant(IntMatrix *inverse) const {
438+
assert(nRows == nColumns &&
439+
"determinant can only be calculated for square matrices!");
440+
441+
FracMatrix m(*this);
442+
443+
FracMatrix fracInverse(nRows, nColumns);
444+
MPInt detM = m.determinant(&fracInverse).getAsInteger();
445+
446+
if (detM == 0)
447+
return MPInt(0);
448+
449+
*inverse = IntMatrix(nRows, nColumns);
450+
for (unsigned i = 0; i < nRows; i++)
451+
for (unsigned j = 0; j < nColumns; j++)
452+
inverse->at(i, j) = (fracInverse.at(i, j) * detM).getAsInteger();
453+
454+
return detM;
455+
}
456+
457+
FracMatrix FracMatrix::identity(unsigned dimension) {
458+
return Matrix::identity(dimension);
459+
}
460+
461+
FracMatrix::FracMatrix(IntMatrix m)
462+
: FracMatrix(m.getNumRows(), m.getNumColumns()) {
463+
for (unsigned i = 0; i < m.getNumRows(); i++)
464+
for (unsigned j = 0; j < m.getNumColumns(); j++)
465+
this->at(i, j) = m.at(i, j);
466+
}
467+
468+
Fraction FracMatrix::determinant(FracMatrix *inverse) const {
469+
assert(nRows == nColumns &&
470+
"determinant can only be calculated for square matrices!");
471+
472+
FracMatrix m(*this);
473+
FracMatrix tempInv(nRows, nColumns);
474+
if (inverse)
475+
tempInv = FracMatrix::identity(nRows);
476+
477+
Fraction a, b;
478+
// Make the matrix into upper triangular form using
479+
// gaussian elimination with row operations.
480+
// If inverse is required, we apply more operations
481+
// to turn the matrix into diagonal form. We apply
482+
// the same operations to the inverse matrix,
483+
// which is initially identity.
484+
// Either way, the product of the diagonal elements
485+
// is then the determinant.
486+
for (unsigned i = 0; i < nRows; i++) {
487+
if (m(i, i) == 0)
488+
// First ensure that the diagonal
489+
// element is nonzero, by swapping
490+
// it with a nonzero row.
491+
for (unsigned j = i + 1; j < nRows; j++) {
492+
if (m(j, i) != 0) {
493+
m.swapRows(j, i);
494+
if (inverse)
495+
tempInv.swapRows(j, i);
496+
break;
497+
}
498+
}
499+
500+
b = m.at(i, i);
501+
if (b == 0)
502+
return 0;
503+
504+
// Set all elements above the
505+
// diagonal to zero.
506+
if (inverse) {
507+
for (unsigned j = 0; j < i; j++) {
508+
if (m.at(j, i) == 0)
509+
continue;
510+
a = m.at(j, i);
511+
// Set element (j, i) to zero
512+
// by subtracting the ith row,
513+
// appropriately scaled.
514+
m.addToRow(i, j, -a / b);
515+
tempInv.addToRow(i, j, -a / b);
516+
}
517+
}
518+
519+
// Set all elements below the
520+
// diagonal to zero.
521+
for (unsigned j = i + 1; j < nRows; j++) {
522+
if (m.at(j, i) == 0)
523+
continue;
524+
a = m.at(j, i);
525+
// Set element (j, i) to zero
526+
// by subtracting the ith row,
527+
// appropriately scaled.
528+
m.addToRow(i, j, -a / b);
529+
if (inverse)
530+
tempInv.addToRow(i, j, -a / b);
531+
}
532+
}
533+
534+
// Now only diagonal elements of m are nonzero, but they are
535+
// not necessarily 1. To get the true inverse, we should
536+
// normalize them and apply the same scale to the inverse matrix.
537+
// For efficiency we skip scaling m and just scale tempInv appropriately.
538+
if (inverse) {
539+
for (unsigned i = 0; i < nRows; i++)
540+
for (unsigned j = 0; j < nRows; j++)
541+
tempInv.at(i, j) = tempInv.at(i, j) / m(i, i);
542+
543+
*inverse = std::move(tempInv);
544+
}
545+
546+
Fraction determinant = 1;
547+
for (unsigned i = 0; i < nRows; i++)
548+
determinant *= m.at(i, i);
549+
550+
return determinant;
435551
}

mlir/unittests/Analysis/Presburger/MatrixTest.cpp

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Analysis/Presburger/Matrix.h"
10-
#include "mlir/Analysis/Presburger/Fraction.h"
1110
#include "./Utils.h"
11+
#include "mlir/Analysis/Presburger/Fraction.h"
1212
#include <gmock/gmock.h>
1313
#include <gtest/gtest.h>
1414

@@ -210,7 +210,8 @@ TEST(MatrixTest, computeHermiteNormalForm) {
210210
{
211211
// Hermite form of a unimodular matrix is the identity matrix.
212212
IntMatrix mat = makeIntMatrix(3, 3, {{2, 3, 6}, {3, 2, 3}, {17, 11, 16}});
213-
IntMatrix hermiteForm = makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
213+
IntMatrix hermiteForm =
214+
makeIntMatrix(3, 3, {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
214215
checkHermiteNormalForm(mat, hermiteForm);
215216
}
216217

@@ -241,10 +242,71 @@ TEST(MatrixTest, computeHermiteNormalForm) {
241242
}
242243

243244
{
244-
IntMatrix mat =
245-
makeIntMatrix(3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
246-
IntMatrix hermiteForm =
247-
makeIntMatrix(3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
245+
IntMatrix mat = makeIntMatrix(
246+
3, 5, {{0, 2, 0, 7, 1}, {-1, 0, 0, -3, 0}, {0, 4, 1, 0, 8}});
247+
IntMatrix hermiteForm = makeIntMatrix(
248+
3, 5, {{1, 0, 0, 0, 0}, {0, 1, 0, 0, 0}, {0, 0, 1, 0, 0}});
248249
checkHermiteNormalForm(mat, hermiteForm);
249250
}
250251
}
252+
253+
TEST(MatrixTest, inverse) {
254+
FracMatrix mat = makeFracMatrix(
255+
2, 2, {{Fraction(2), Fraction(1)}, {Fraction(7), Fraction(0)}});
256+
FracMatrix inverse = makeFracMatrix(
257+
2, 2, {{Fraction(0), Fraction(1, 7)}, {Fraction(1), Fraction(-2, 7)}});
258+
259+
FracMatrix inv(2, 2);
260+
mat.determinant(&inv);
261+
262+
EXPECT_EQ_FRAC_MATRIX(inv, inverse);
263+
264+
mat = makeFracMatrix(
265+
2, 2, {{Fraction(0), Fraction(1)}, {Fraction(0), Fraction(2)}});
266+
Fraction det = mat.determinant(nullptr);
267+
268+
EXPECT_EQ(det, Fraction(0));
269+
270+
mat = makeFracMatrix(3, 3,
271+
{{Fraction(1), Fraction(2), Fraction(3)},
272+
{Fraction(4), Fraction(8), Fraction(6)},
273+
{Fraction(7), Fraction(8), Fraction(6)}});
274+
inverse = makeFracMatrix(3, 3,
275+
{{Fraction(0), Fraction(-1, 3), Fraction(1, 3)},
276+
{Fraction(-1, 2), Fraction(5, 12), Fraction(-1, 6)},
277+
{Fraction(2, 3), Fraction(-1, 6), Fraction(0)}});
278+
279+
mat.determinant(&inv);
280+
EXPECT_EQ_FRAC_MATRIX(inv, inverse);
281+
282+
mat = makeFracMatrix(0, 0, {});
283+
mat.determinant(&inv);
284+
}
285+
286+
TEST(MatrixTest, intInverse) {
287+
IntMatrix mat = makeIntMatrix(2, 2, {{2, 1}, {7, 0}});
288+
IntMatrix inverse = makeIntMatrix(2, 2, {{0, -1}, {-7, 2}});
289+
290+
IntMatrix inv(2, 2);
291+
mat.determinant(&inv);
292+
293+
EXPECT_EQ_INT_MATRIX(inv, inverse);
294+
295+
mat = makeIntMatrix(
296+
4, 4, {{4, 14, 11, 3}, {13, 5, 14, 12}, {13, 9, 7, 14}, {2, 3, 12, 7}});
297+
inverse = makeIntMatrix(4, 4,
298+
{{155, 1636, -579, -1713},
299+
{725, -743, 537, -111},
300+
{210, 735, -855, 360},
301+
{-715, -1409, 1401, 1482}});
302+
303+
mat.determinant(&inv);
304+
305+
EXPECT_EQ_INT_MATRIX(inv, inverse);
306+
307+
mat = makeIntMatrix(2, 2, {{0, 0}, {1, 2}});
308+
309+
MPInt det = mat.determinant(&inv);
310+
311+
EXPECT_EQ(det, 0);
312+
}

mlir/unittests/Analysis/Presburger/Utils.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
#define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_UTILS_H
1515

1616
#include "mlir/Analysis/Presburger/IntegerRelation.h"
17+
#include "mlir/Analysis/Presburger/Matrix.h"
1718
#include "mlir/Analysis/Presburger/PWMAFunction.h"
1819
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
1920
#include "mlir/Analysis/Presburger/Simplex.h"
20-
#include "mlir/Analysis/Presburger/Matrix.h"
2121
#include "mlir/IR/MLIRContext.h"
2222
#include "mlir/Support/LLVM.h"
2323

@@ -28,7 +28,7 @@ namespace mlir {
2828
namespace presburger {
2929

3030
inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns,
31-
ArrayRef<SmallVector<int, 8>> matrix) {
31+
ArrayRef<SmallVector<int, 8>> matrix) {
3232
IntMatrix results(numRow, numColumns);
3333
assert(matrix.size() == numRow);
3434
for (unsigned i = 0; i < numRow; ++i) {
@@ -40,9 +40,9 @@ inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns,
4040
return results;
4141
}
4242

43-
inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
44-
ArrayRef<SmallVector<Fraction, 8>> matrix) {
45-
Matrix<Fraction> results(numRow, numColumns);
43+
inline FracMatrix makeFracMatrix(unsigned numRow, unsigned numColumns,
44+
ArrayRef<SmallVector<Fraction, 8>> matrix) {
45+
FracMatrix results(numRow, numColumns);
4646
assert(matrix.size() == numRow);
4747
for (unsigned i = 0; i < numRow; ++i) {
4848
assert(matrix[i].size() == numColumns &&
@@ -53,6 +53,24 @@ inline Matrix<Fraction> makeFracMatrix(unsigned numRow, unsigned numColumns,
5353
return results;
5454
}
5555

56+
inline void EXPECT_EQ_INT_MATRIX(IntMatrix a, IntMatrix b) {
57+
EXPECT_EQ(a.getNumRows(), b.getNumRows());
58+
EXPECT_EQ(a.getNumColumns(), b.getNumColumns());
59+
60+
for (unsigned row = 0; row < a.getNumRows(); row++)
61+
for (unsigned col = 0; col < a.getNumColumns(); col++)
62+
EXPECT_EQ(a(row, col), b(row, col));
63+
}
64+
65+
inline void EXPECT_EQ_FRAC_MATRIX(FracMatrix a, FracMatrix b) {
66+
EXPECT_EQ(a.getNumRows(), b.getNumRows());
67+
EXPECT_EQ(a.getNumColumns(), b.getNumColumns());
68+
69+
for (unsigned row = 0; row < a.getNumRows(); row++)
70+
for (unsigned col = 0; col < a.getNumColumns(); col++)
71+
EXPECT_EQ(a(row, col), b(row, col));
72+
}
73+
5674
/// lhs and rhs represent non-negative integers or positive infinity. The
5775
/// infinity case corresponds to when the Optional is empty.
5876
inline bool infinityOrUInt64LE(std::optional<MPInt> lhs,

0 commit comments

Comments
 (0)