Skip to content

Commit 84ab06b

Browse files
[MLIR][Presburger] Add Gram-Schmidt (#70843)
Implement Gram-Schmidt orthogonalisation for the FracMatrix class. This requires dotProduct, which has been added as a util.
1 parent f64a057 commit 84ab06b

File tree

5 files changed

+111
-2
lines changed

5 files changed

+111
-2
lines changed

mlir/include/mlir/Analysis/Presburger/Matrix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,11 @@ class FracMatrix : public Matrix<Fraction> {
265265
// does not exist, which happens iff det = 0.
266266
// Assert-fails if the matrix is not square.
267267
Fraction determinant(FracMatrix *inverse = nullptr) const;
268+
269+
// Computes the Gram-Schmidt orthogonalisation
270+
// of the rows of matrix (cubic time).
271+
// The rows of the matrix must be linearly independent.
272+
FracMatrix gramSchmidt() const;
268273
};
269274

270275
} // namespace presburger

mlir/include/mlir/Analysis/Presburger/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ SmallVector<MPInt, 8> getNegatedCoeffs(ArrayRef<MPInt> coeffs);
276276
/// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0,
277277
/// since all the variables are constrained to be integers.
278278
SmallVector<MPInt, 8> getComplementIneq(ArrayRef<MPInt> ineq);
279+
280+
/// Compute the dot product of two vectors.
281+
/// The vectors must have the same sizes.
282+
Fraction dotProduct(ArrayRef<Fraction> a, ArrayRef<Fraction> b);
283+
279284
} // namespace presburger
280285
} // namespace mlir
281286

mlir/lib/Analysis/Presburger/Matrix.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,8 @@ FracMatrix FracMatrix::identity(unsigned dimension) {
466466

467467
FracMatrix::FracMatrix(IntMatrix m)
468468
: FracMatrix(m.getNumRows(), m.getNumColumns()) {
469-
for (unsigned i = 0; i < m.getNumRows(); i++)
470-
for (unsigned j = 0; j < m.getNumColumns(); j++)
469+
for (unsigned i = 0, r = m.getNumRows(); i < r; i++)
470+
for (unsigned j = 0, c = m.getNumColumns(); j < c; j++)
471471
this->at(i, j) = m.at(i, j);
472472
}
473473

@@ -554,4 +554,26 @@ Fraction FracMatrix::determinant(FracMatrix *inverse) const {
554554
determinant *= m.at(i, i);
555555

556556
return determinant;
557+
}
558+
559+
FracMatrix FracMatrix::gramSchmidt() const {
560+
// Create a copy of the argument to store
561+
// the orthogonalised version.
562+
FracMatrix orth(*this);
563+
564+
// For each vector (row) in the matrix, subtract its unit
565+
// projection along each of the previous vectors.
566+
// This ensures that it has no component in the direction
567+
// of any of the previous vectors.
568+
for (unsigned i = 1, e = getNumRows(); i < e; i++) {
569+
for (unsigned j = 0; j < i; j++) {
570+
Fraction jNormSquared = dotProduct(orth.getRow(j), orth.getRow(j));
571+
assert(jNormSquared != 0 && "some row became zero! Inputs to this "
572+
"function must be linearly independent.");
573+
Fraction projectionScale =
574+
dotProduct(orth.getRow(i), orth.getRow(j)) / jNormSquared;
575+
orth.addToRow(j, i, -projectionScale);
576+
}
577+
}
578+
return orth;
557579
}

mlir/lib/Analysis/Presburger/Utils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,12 @@ SmallVector<int64_t, 8> presburger::getInt64Vec(ArrayRef<MPInt> range) {
529529
std::transform(range.begin(), range.end(), result.begin(), int64FromMPInt);
530530
return result;
531531
}
532+
533+
Fraction presburger::dotProduct(ArrayRef<Fraction> a, ArrayRef<Fraction> b) {
534+
assert(a.size() == b.size() &&
535+
"dot product is only valid for vectors of equal sizes!");
536+
Fraction sum = 0;
537+
for (unsigned i = 0, e = a.size(); i < e; i++)
538+
sum += a[i] * b[i];
539+
return sum;
540+
}

mlir/unittests/Analysis/Presburger/MatrixTest.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,71 @@ TEST(MatrixTest, intInverse) {
310310

311311
EXPECT_EQ(det, 0);
312312
}
313+
314+
TEST(MatrixTest, gramSchmidt) {
315+
FracMatrix mat =
316+
makeFracMatrix(3, 5,
317+
{{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1),
318+
Fraction(12, 1), Fraction(19, 1)},
319+
{Fraction(4, 1), Fraction(5, 1), Fraction(6, 1),
320+
Fraction(13, 1), Fraction(20, 1)},
321+
{Fraction(7, 1), Fraction(8, 1), Fraction(9, 1),
322+
Fraction(16, 1), Fraction(24, 1)}});
323+
324+
FracMatrix gramSchmidt = makeFracMatrix(
325+
3, 5,
326+
{{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1), Fraction(12, 1),
327+
Fraction(19, 1)},
328+
{Fraction(142, 185), Fraction(383, 555), Fraction(68, 111),
329+
Fraction(13, 185), Fraction(-262, 555)},
330+
{Fraction(53, 463), Fraction(27, 463), Fraction(1, 463),
331+
Fraction(-181, 463), Fraction(100, 463)}});
332+
333+
FracMatrix gs = mat.gramSchmidt();
334+
335+
EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
336+
for (unsigned i = 0; i < 3u; i++)
337+
for (unsigned j = i + 1; j < 3u; j++)
338+
EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
339+
340+
mat = makeFracMatrix(3, 3,
341+
{{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
342+
{Fraction(20, 1), Fraction(18, 1), Fraction(6, 1)},
343+
{Fraction(15, 1), Fraction(14, 1), Fraction(10, 1)}});
344+
345+
gramSchmidt = makeFracMatrix(
346+
3, 3,
347+
{{Fraction(20, 1), Fraction(17, 1), Fraction(10, 1)},
348+
{Fraction(460, 789), Fraction(1180, 789), Fraction(-2926, 789)},
349+
{Fraction(-2925, 3221), Fraction(3000, 3221), Fraction(750, 3221)}});
350+
351+
gs = mat.gramSchmidt();
352+
353+
EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
354+
for (unsigned i = 0; i < 3u; i++)
355+
for (unsigned j = i + 1; j < 3u; j++)
356+
EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
357+
358+
mat = makeFracMatrix(
359+
4, 4,
360+
{{Fraction(1, 26), Fraction(13, 12), Fraction(34, 13), Fraction(7, 10)},
361+
{Fraction(40, 23), Fraction(34, 1), Fraction(11, 19), Fraction(15, 1)},
362+
{Fraction(21, 22), Fraction(10, 9), Fraction(4, 11), Fraction(14, 11)},
363+
{Fraction(35, 22), Fraction(1, 15), Fraction(5, 8), Fraction(30, 1)}});
364+
365+
gs = mat.gramSchmidt();
366+
367+
// The integers involved are too big to construct the actual matrix.
368+
// but we can check that the result is linearly independent.
369+
ASSERT_FALSE(mat.determinant(nullptr) == 0);
370+
371+
for (unsigned i = 0; i < 4u; i++)
372+
for (unsigned j = i + 1; j < 4u; j++)
373+
EXPECT_EQ(dotProduct(gs.getRow(i), gs.getRow(j)), 0);
374+
375+
mat = FracMatrix::identity(/*dim=*/10);
376+
377+
gs = mat.gramSchmidt();
378+
379+
EXPECT_EQ_FRAC_MATRIX(gs, FracMatrix::identity(10));
380+
}

0 commit comments

Comments
 (0)