Skip to content

Commit cfd51fb

Browse files
[MLIR][Presburger] Add LLL basis reduction (#75565)
Add a method for LLL basis reduction to the FracMatrix class. This needs an abs() method for Fractions, which is added to Fraction.h.
1 parent ea43c8e commit cfd51fb

File tree

4 files changed

+133
-1
lines changed

4 files changed

+133
-1
lines changed

mlir/include/mlir/Analysis/Presburger/Fraction.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ inline bool operator>=(const Fraction &x, const Fraction &y) {
101101
return compare(x, y) >= 0;
102102
}
103103

104+
inline Fraction abs(const Fraction &f) {
105+
assert(f.den > 0 && "denominator of fraction must be positive!");
106+
return Fraction(abs(f.num), f.den);
107+
}
108+
104109
inline Fraction reduce(const Fraction &f) {
105110
if (f == Fraction(0))
106111
return Fraction(0, 1);
@@ -124,6 +129,12 @@ inline Fraction operator-(const Fraction &x, const Fraction &y) {
124129
return reduce(Fraction(x.num * y.den - x.den * y.num, x.den * y.den));
125130
}
126131

132+
// Find the integer nearest to a given fraction.
133+
inline MPInt round(const Fraction &f) {
134+
MPInt rem = f.num % f.den;
135+
return (f.num / f.den) + (rem > f.den / 2);
136+
}
137+
127138
inline Fraction &operator+=(Fraction &x, const Fraction &y) {
128139
x = x + y;
129140
return x;

mlir/include/mlir/Analysis/Presburger/Matrix.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ class FracMatrix : public Matrix<Fraction> {
270270
// of the rows of matrix (cubic time).
271271
// The rows of the matrix must be linearly independent.
272272
FracMatrix gramSchmidt() const;
273+
274+
// Run LLL basis reduction on the matrix, modifying it in-place.
275+
// The parameter is what [the original
276+
// paper](https://www.cs.cmu.edu/~avrim/451f11/lectures/lect1129_LLL.pdf)
277+
// calls `y`, usually 3/4.
278+
void LLL(Fraction delta);
273279
};
274280

275281
} // namespace presburger

mlir/lib/Analysis/Presburger/Matrix.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,4 +576,71 @@ FracMatrix FracMatrix::gramSchmidt() const {
576576
}
577577
}
578578
return orth;
579-
}
579+
}
580+
581+
// Convert the matrix, interpreted (row-wise) as a basis
582+
// to an LLL-reduced basis.
583+
//
584+
// This is an implementation of the algorithm described in
585+
// "Factoring polynomials with rational coefficients" by
586+
// A. K. Lenstra, H. W. Lenstra Jr., L. Lovasz.
587+
//
588+
// Let {b_1, ..., b_n} be the current basis and
589+
// {b_1*, ..., b_n*} be the Gram-Schmidt orthogonalised
590+
// basis (unnormalized).
591+
// Define the Gram-Schmidt coefficients μ_ij as
592+
// (b_i • b_j*) / (b_j* • b_j*), where (•) represents the inner product.
593+
//
594+
// We iterate starting from the second row to the last row.
595+
//
596+
// For the kth row, we first check μ_kj for all rows j < k.
597+
// We subtract b_j (scaled by the integer nearest to μ_kj)
598+
// from b_k.
599+
//
600+
// Now, we update k.
601+
// If b_k and b_{k-1} satisfy the Lovasz condition
602+
// |b_k|^2 ≥ (δ - μ_k{k-1}^2) |b_{k-1}|^2,
603+
// we are done and we increment k.
604+
// Otherwise, we swap b_k and b_{k-1} and decrement k.
605+
//
606+
// We repeat this until k = n and return.
607+
void FracMatrix::LLL(Fraction delta) {
608+
MPInt nearest;
609+
Fraction mu;
610+
611+
// `gsOrth` holds the Gram-Schmidt orthogonalisation
612+
// of the matrix at all times. It is recomputed every
613+
// time the matrix is modified during the algorithm.
614+
// This is naive and can be optimised.
615+
FracMatrix gsOrth = gramSchmidt();
616+
617+
// We start from the second row.
618+
unsigned k = 1;
619+
while (k < getNumRows()) {
620+
for (unsigned j = k - 1; j < k; j--) {
621+
// Compute the Gram-Schmidt coefficient μ_jk.
622+
mu = dotProduct(getRow(k), gsOrth.getRow(j)) /
623+
dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
624+
nearest = round(mu);
625+
// Subtract b_j scaled by the integer nearest to μ_jk from b_k.
626+
addToRow(k, getRow(j), -Fraction(nearest, 1));
627+
gsOrth = gramSchmidt(); // Update orthogonalization.
628+
}
629+
mu = dotProduct(getRow(k), gsOrth.getRow(k - 1)) /
630+
dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1));
631+
// Check the Lovasz condition for b_k and b_{k-1}.
632+
if (dotProduct(gsOrth.getRow(k), gsOrth.getRow(k)) >
633+
(delta - mu * mu) *
634+
dotProduct(gsOrth.getRow(k - 1), gsOrth.getRow(k - 1))) {
635+
// If it is satisfied, proceed to the next k.
636+
k += 1;
637+
} else {
638+
// If it is not satisfied, decrement k (without
639+
// going beyond the second row).
640+
swapRows(k, k - 1);
641+
gsOrth = gramSchmidt(); // Update orthogonalization.
642+
k = k > 1 ? k - 1 : 1;
643+
}
644+
}
645+
return;
646+
}

mlir/unittests/Analysis/Presburger/MatrixTest.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,4 +377,52 @@ TEST(MatrixTest, gramSchmidt) {
377377
gs = mat.gramSchmidt();
378378

379379
EXPECT_EQ_FRAC_MATRIX(gs, FracMatrix::identity(10));
380+
}
381+
382+
void checkReducedBasis(FracMatrix mat, Fraction delta) {
383+
FracMatrix gsOrth = mat.gramSchmidt();
384+
385+
// Size-reduced check.
386+
for (unsigned i = 0, e = mat.getNumRows(); i < e; i++) {
387+
for (unsigned j = 0; j < i; j++) {
388+
Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(j)) /
389+
dotProduct(gsOrth.getRow(j), gsOrth.getRow(j));
390+
EXPECT_TRUE(abs(mu) <= Fraction(1, 2));
391+
}
392+
}
393+
394+
// Lovasz condition check.
395+
for (unsigned i = 1, e = mat.getNumRows(); i < e; i++) {
396+
Fraction mu = dotProduct(mat.getRow(i), gsOrth.getRow(i - 1)) /
397+
dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1));
398+
EXPECT_TRUE(dotProduct(mat.getRow(i), mat.getRow(i)) >
399+
(delta - mu * mu) *
400+
dotProduct(gsOrth.getRow(i - 1), gsOrth.getRow(i - 1)));
401+
}
402+
}
403+
404+
TEST(MatrixTest, LLL) {
405+
FracMatrix mat =
406+
makeFracMatrix(3, 3,
407+
{{Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)},
408+
{Fraction(-1, 1), Fraction(0, 1), Fraction(2, 1)},
409+
{Fraction(3, 1), Fraction(5, 1), Fraction(6, 1)}});
410+
mat.LLL(Fraction(3, 4));
411+
412+
checkReducedBasis(mat, Fraction(3, 4));
413+
414+
mat = makeFracMatrix(
415+
2, 2,
416+
{{Fraction(12, 1), Fraction(2, 1)}, {Fraction(13, 1), Fraction(4, 1)}});
417+
mat.LLL(Fraction(3, 4));
418+
419+
checkReducedBasis(mat, Fraction(3, 4));
420+
421+
mat = makeFracMatrix(3, 3,
422+
{{Fraction(1, 1), Fraction(0, 1), Fraction(2, 1)},
423+
{Fraction(0, 1), Fraction(1, 3), -Fraction(5, 3)},
424+
{Fraction(0, 1), Fraction(0, 1), Fraction(1, 1)}});
425+
mat.LLL(Fraction(3, 4));
426+
427+
checkReducedBasis(mat, Fraction(3, 4));
380428
}

0 commit comments

Comments
 (0)