Skip to content

[MLIR][Presburger] Add Gram-Schmidt #70843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 13, 2023
5 changes: 5 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ class FracMatrix : public Matrix<Fraction> {
// does not exist, which happens iff det = 0.
// Assert-fails if the matrix is not square.
Fraction determinant(FracMatrix *inverse = nullptr) const;

// Computes the Gram-Schmidt orthogonalisation
// of the rows of matrix (cubic time).
// The rows of the matrix must be linearly independent.
FracMatrix gramSchmidt() const;
};

} // namespace presburger
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ SmallVector<MPInt, 8> getNegatedCoeffs(ArrayRef<MPInt> coeffs);
/// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0,
/// since all the variables are constrained to be integers.
SmallVector<MPInt, 8> getComplementIneq(ArrayRef<MPInt> ineq);

/// Compute the dot product of two vectors.
/// The vectors must have the same sizes.
Fraction dotProduct(ArrayRef<Fraction> a, ArrayRef<Fraction> b);

} // namespace presburger
} // namespace mlir

Expand Down
26 changes: 24 additions & 2 deletions mlir/lib/Analysis/Presburger/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ FracMatrix FracMatrix::identity(unsigned dimension) {

FracMatrix::FracMatrix(IntMatrix m)
: FracMatrix(m.getNumRows(), m.getNumColumns()) {
for (unsigned i = 0; i < m.getNumRows(); i++)
for (unsigned j = 0; j < m.getNumColumns(); j++)
for (unsigned i = 0, r = m.getNumRows(); i < r; i++)
for (unsigned j = 0, c = m.getNumColumns(); j < c; j++)
this->at(i, j) = m.at(i, j);
}

Expand Down Expand Up @@ -548,4 +548,26 @@ Fraction FracMatrix::determinant(FracMatrix *inverse) const {
determinant *= m.at(i, i);

return determinant;
}

FracMatrix FracMatrix::gramSchmidt() const {
// Create a copy of the argument to store
// the orthogonalised version.
FracMatrix orth(*this);

// For each vector (row) in the matrix, subtract its unit
// projection along each of the previous vectors.
// This ensures that it has no component in the direction
// of any of the previous vectors.
for (unsigned i = 1, e = getNumRows(); i < e; i++) {
for (unsigned j = 0; j < i; j++) {
Fraction jNormSquared = dotProduct(orth.getRow(j), orth.getRow(j));
assert(jNormSquared != 0 && "some row became zero! Inputs to this "
"function must be linearly independent.");
Fraction projectionScale =
dotProduct(orth.getRow(i), orth.getRow(j)) / jNormSquared;
orth.addToRow(j, i, -projectionScale);
}
}
return orth;
}
9 changes: 9 additions & 0 deletions mlir/lib/Analysis/Presburger/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,12 @@ SmallVector<int64_t, 8> presburger::getInt64Vec(ArrayRef<MPInt> range) {
std::transform(range.begin(), range.end(), result.begin(), int64FromMPInt);
return result;
}

Fraction presburger::dotProduct(ArrayRef<Fraction> a, ArrayRef<Fraction> b) {
assert(a.size() == b.size() &&
"dot product is only valid for vectors of equal sizes!");
Fraction sum = 0;
for (unsigned i = 0, e = a.size(); i < e; i++)
sum += a[i] * b[i];
return sum;
}
68 changes: 68 additions & 0 deletions mlir/unittests/Analysis/Presburger/MatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,71 @@ TEST(MatrixTest, intInverse) {

EXPECT_EQ(det, 0);
}

TEST(MatrixTest, gramSchmidt) {
FracMatrix mat =
makeFracMatrix(3, 5,
{{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1),
Fraction(12, 1), Fraction(19, 1)},
{Fraction(4, 1), Fraction(5, 1), Fraction(6, 1),
Fraction(13, 1), Fraction(20, 1)},
{Fraction(7, 1), Fraction(8, 1), Fraction(9, 1),
Fraction(16, 1), Fraction(24, 1)}});

FracMatrix gramSchmidt = makeFracMatrix(
3, 5,
{{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1), Fraction(12, 1),
Fraction(19, 1)},
{Fraction(142, 185), Fraction(383, 555), Fraction(68, 111),
Fraction(13, 185), Fraction(-262, 555)},
{Fraction(53, 463), Fraction(27, 463), Fraction(1, 463),
Fraction(-181, 463), Fraction(100, 463)}});

FracMatrix gs = mat.gramSchmidt();

EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
for (unsigned i = 0; i < 3u; i++)
for (unsigned j = i + 1; j < 3u; j++)
EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);

mat = makeFracMatrix(3, 3,
{{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
{Fraction(20, 1), Fraction(18, 1), Fraction(6, 1)},
{Fraction(15, 1), Fraction(14, 1), Fraction(10, 1)}});

gramSchmidt = makeFracMatrix(
3, 3,
{{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
{Fraction(460, 789), Fraction(1180, 789), Fraction(-2926, 789)},
{Fraction(-2925, 3221), Fraction(3000, 3221), Fraction(750, 3221)}});

gs = mat.gramSchmidt();

EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
for (unsigned i = 0; i < 3u; i++)
for (unsigned j = i + 1; j < 3u; j++)
EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);

mat = makeFracMatrix(
4, 4,
{{Fraction(1, 26), Fraction(13, 12), Fraction(34, 13), Fraction(7, 10)},
{Fraction(40, 23), Fraction(34, 1), Fraction(11, 19), Fraction(15, 1)},
{Fraction(21, 22), Fraction(10, 9), Fraction(4, 11), Fraction(14, 11)},
{Fraction(35, 22), Fraction(1, 15), Fraction(5, 8), Fraction(30, 1)}});

gs = mat.gramSchmidt();

// The integers involved are too big to construct the actual matrix.
// but we can check that the result is linearly independent.
ASSERT_FALSE(mat.determinant(nullptr) == 0);

for (unsigned i = 0; i < 4u; i++)
for (unsigned j = i + 1; j < 4u; j++)
EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);

mat = FracMatrix::identity(/*dim=*/10);

gs = mat.gramSchmidt();

EXPECT_EQ_FRAC_MATRIX(gs, FracMatrix::identity(10));
}