Skip to content

[MLIR][Presburger] Add Gram-Schmidt #70843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 13, 2023
Merged

Conversation

Abhinav271828
Copy link
Contributor

Implement Gram-Schmidt orthogonalisation for the FracMatrix class.
This requires dotProduct, which has been added as a util.

@Abhinav271828 Abhinav271828 changed the title [MLIR][Presburger] [MLIR][Presburger] Add Gram-Schmidt Oct 31, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2023

@llvm/pr-subscribers-mlir-presburger

@llvm/pr-subscribers-mlir

Author: None (Abhinav271828)

Changes

Implement Gram-Schmidt orthogonalisation for the FracMatrix class.
This requires dotProduct, which has been added as a util.


Full diff: https://github.com/llvm/llvm-project/pull/70843.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Analysis/Presburger/Matrix.h (+5)
  • (modified) mlir/include/mlir/Analysis/Presburger/Utils.h (+5)
  • (modified) mlir/lib/Analysis/Presburger/Matrix.cpp (+23)
  • (modified) mlir/lib/Analysis/Presburger/Utils.cpp (+8)
  • (modified) mlir/unittests/Analysis/Presburger/MatrixTest.cpp (+15)
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 4d9f13832e0692a..b591b0b4fdad167 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -265,6 +265,11 @@ class FracMatrix : public Matrix<Fraction> {
   // does not exist, which happens iff det = 0.
   // Assert-fails if the matrix is not square.
   Fraction determinant(FracMatrix *inverse = nullptr) const;
+
+  // Computes the Gram-Schmidt orthogonalisation
+  // of the matrix (cubic time).
+  FracMatrix gramSchmidt() const;
+
 };
 
 } // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index a451ae8bf55723e..639683cea0d9f2e 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -276,6 +276,11 @@ SmallVector<MPInt, 8> getNegatedCoeffs(ArrayRef<MPInt> coeffs);
 /// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0,
 /// since all the variables are constrained to be integers.
 SmallVector<MPInt, 8> getComplementIneq(ArrayRef<MPInt> ineq);
+
+/// Compute the dot product of two vectors.
+/// Assumes that the vectors have the same sizes.
+Fraction dotProduct(MutableArrayRef<Fraction> a, MutableArrayRef<Fraction> b);
+
 } // namespace presburger
 } // namespace mlir
 
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index ae97e456d9820cf..7669e461f079472 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -548,4 +548,27 @@ Fraction FracMatrix::determinant(FracMatrix *inverse) const {
     determinant *= m.at(i, i);
 
   return determinant;
+}
+
+FracMatrix FracMatrix::gramSchmidt() const {
+
+    // Create a copy of the argument to store
+    // the orthogonalised version.
+    FracMatrix orth(*this);
+    Fraction projectionScale;
+
+    // For each vector (row) in the matrix,
+    // subtract its unit projection along
+    // each of the previous vectors.
+    // This ensures that it has no component
+    // in the direction of any of the
+    // previous vectors.
+    for (unsigned i = 1; i < getNumRows(); i++) {
+        for (unsigned j = 0; j < i; j++) {
+            projectionScale = dotProduct(orth.getRow(i), orth.getRow(j)) /
+                              dotProduct(orth.getRow(j), orth.getRow(j));
+            orth.addToRow(j, i, -projectionScale);
+        }
+    }
+    return orth;
 }
\ No newline at end of file
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index 9aef2f5de109357..30f01f545819101 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -520,3 +520,11 @@ SmallVector<int64_t, 8> presburger::getInt64Vec(ArrayRef<MPInt> range) {
   std::transform(range.begin(), range.end(), result.begin(), int64FromMPInt);
   return result;
 }
+
+Fraction presburger::dotProduct(MutableArrayRef<Fraction> a, MutableArrayRef<Fraction> b) {
+  assert(a.size() == b.size() && "Dot product of two unequal vectors!");
+  Fraction sum = 0;
+  for (unsigned i = 0; i < a.size(); i++)
+    sum += a[i] * b[i];
+  return sum;
+}
\ No newline at end of file
diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
index d05b05e004c5c5f..b7cc8e07a54e1ea 100644
--- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp
@@ -310,3 +310,18 @@ TEST(MatrixTest, intInverse) {
 
   EXPECT_EQ(det, 0);
 }
+
+TEST(MatrixTest, gramSchmidt) {
+  FracMatrix mat = makeFracMatrix(3, 5, {{Fraction(3, 1), Fraction(4, 1), Fraction(5, 1), Fraction(12, 1), Fraction(19, 1)},
+                                         {Fraction(4, 1), Fraction(5, 1), Fraction(6, 1), Fraction(13, 1), Fraction(20, 1)},
+                                         {Fraction(7, 1), Fraction(8, 1), Fraction(9, 1), Fraction(16, 1), Fraction(24, 1)}});
+
+  FracMatrix gramSchmidt = makeFracMatrix(3, 5,
+         {{Fraction(3, 1),     Fraction(4, 1),     Fraction(5, 1),    Fraction(12, 1),     Fraction(19, 1)},
+          {Fraction(142, 185), Fraction(383, 555), Fraction(68, 111), Fraction(13, 185),   Fraction(-262, 555)},
+          {Fraction(53, 463),  Fraction(27, 463),  Fraction(1, 463),  Fraction(-181, 463), Fraction(100, 463)}});
+
+  FracMatrix gs = mat.gramSchmidt();
+
+  EXPECT_EQ_FRAC_MATRIX(gs, gramSchmidt);
+}
\ No newline at end of file

Copy link

github-actions bot commented Oct 31, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@Superty Superty self-requested a review October 31, 2023 18:11
Copy link
Member

@Superty Superty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use slightly longer lines for comments, you can make use of the full 80 lines of space that are available. You can just write long lines and let the formatter wrap them. You don't have to worry about keeping them short.

@Superty
Copy link
Member

Superty commented Dec 8, 2023

I prefer an assert statement that clearly says what actually happened

"Some row became zero! Inputs to this function must be linearly independent."

@Superty Superty merged commit 84ab06b into llvm:main Dec 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants