Skip to content

Commit 39b9395

Browse files
authored
[MLIR][Presburger] Add simplify function (#69107)
Added the simplify function to reduce the size of the constraint system, referencing the ISL implementation. Tested it on a simple Benchmark implemented by myself, calling SImplify before the operation and calling Simplify on the result after Subtract were tested, respectively. The Benchmark can be found here: [benchmark](https://github.com/gilsaia/llvm-project-test-fpl/blob/develop_benchmark/mlir/benchmark/presburger/Benchmark.cpp) For the case of calling Simplify before each operation, the overall result is shown in the following figure. ![image](https://github.com/llvm/llvm-project/assets/38588948/7099286e-b9a2-42e0-bc2a-1ed6627ead00) A comparison of the constraint system sizes and time for each operation is as follows ![image](https://github.com/llvm/llvm-project/assets/38588948/e5d0e488-f76e-4438-b19e-f6163699c526) ![image](https://github.com/llvm/llvm-project/assets/38588948/119a08de-4ee1-4cde-886c-50a91b502d93) ![image](https://github.com/llvm/llvm-project/assets/38588948/7a8b69ac-6cdb-41ab-9a75-cd016664fa5a) ![image](https://github.com/llvm/llvm-project/assets/38588948/c84b6eb1-62dc-4bae-a771-67d97ebf514a) ![image](https://github.com/llvm/llvm-project/assets/38588948/cdbfa3ed-0155-481e-9273-9d6dba3a2d7b) ![image](https://github.com/llvm/llvm-project/assets/38588948/8c945cff-a0a4-472a-a178-6b6a70a1b16a) ![image](https://github.com/llvm/llvm-project/assets/38588948/0bfe3a2b-3568-4d31-bebf-bd1b3c4e734e) ![image](https://github.com/llvm/llvm-project/assets/38588948/f1a99d56-edf5-45de-a506-512c0584f1d8) ![image](https://github.com/llvm/llvm-project/assets/38588948/ffef3312-6c99-494c-bb52-73aa8df275bb) ![image](https://github.com/llvm/llvm-project/assets/38588948/3e5924a7-8e1f-49d1-bd27-02a2e10a5cc4) ![image](https://github.com/llvm/llvm-project/assets/38588948/cec8be0e-dd19-46fa-88b4-2585d4031c9e) ![image](https://github.com/llvm/llvm-project/assets/38588948/3cb68e89-82c7-4cd2-b6bc-70f15e495ce8) For the case of calling Simplify on the result after Subtract, the overall results are as follows ![image](https://github.com/llvm/llvm-project/assets/38588948/be5b9c50-7417-42c8-abbf-8a50f093c3f5) A comparison of the constraint system sizes and time for subtract is as follows ![image](https://github.com/llvm/llvm-project/assets/38588948/fafe10ba-f8bd-43cd-b281-aaebf09af0af) ![image](https://github.com/llvm/llvm-project/assets/38588948/24662b40-42fc-47ee-a0a3-1b8b8f5778d2)
1 parent d4b8572 commit 39b9395

File tree

4 files changed

+200
-20
lines changed

4 files changed

+200
-20
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,15 @@ class IntegerRelation {
9191
return IntegerRelation(space);
9292
}
9393

94+
/// Return an empty system containing an invalid equation 0 = 1.
95+
static IntegerRelation getEmpty(const PresburgerSpace &space) {
96+
IntegerRelation result(0, 1, space.getNumVars() + 1, space);
97+
SmallVector<int64_t> invalidEq(space.getNumVars() + 1, 0);
98+
invalidEq.back() = 1;
99+
result.addEquality(invalidEq);
100+
return result;
101+
}
102+
94103
/// Return the kind of this IntegerRelation.
95104
virtual Kind getKind() const { return Kind::IntegerRelation; }
96105

@@ -138,7 +147,7 @@ class IntegerRelation {
138147
/// returns false. The equality check is performed in a plain manner, by
139148
/// comparing if all the equalities and inequalities in `this` and `other`
140149
/// are the same.
141-
bool isPlainEqual(const IntegerRelation &other) const;
150+
bool isObviouslyEqual(const IntegerRelation &other) const;
142151

143152
/// Return whether this is a subset of the given IntegerRelation. This is
144153
/// integer-exact and somewhat expensive, since it uses the integer emptiness
@@ -351,6 +360,9 @@ class IntegerRelation {
351360
/// Returns false otherwise.
352361
bool isEmpty() const;
353362

363+
/// Performs GCD checks and invalid constraint checks.
364+
bool isObviouslyEmpty() const;
365+
354366
/// Runs the GCD test on all equality constraints. Returns true if this test
355367
/// fails on any equality. Returns false otherwise.
356368
/// This test can be used to disprove the existence of a solution. If it
@@ -545,6 +557,10 @@ class IntegerRelation {
545557

546558
void removeDuplicateDivs();
547559

560+
/// Simplify the constraint system by removing canonicalizing constraints and
561+
/// removing redundant constraints.
562+
void simplify();
563+
548564
/// Converts variables of kind srcKind in the range [varStart, varLimit) to
549565
/// variables of kind dstKind. If `pos` is given, the variables are placed at
550566
/// position `pos` of dstKind, otherwise they are placed after all the other
@@ -724,6 +740,10 @@ class IntegerRelation {
724740
/// Returns the number of variables eliminated.
725741
unsigned gaussianEliminateVars(unsigned posStart, unsigned posLimit);
726742

743+
/// Perform a Gaussian elimination operation to reduce all equations to
744+
/// standard form. Returns whether the constraint system was modified.
745+
bool gaussianEliminate();
746+
727747
/// Eliminates the variable at the specified position using Fourier-Motzkin
728748
/// variable elimination, but uses Gaussian elimination if there is an
729749
/// equality involving that variable. If the result of the elimination is
@@ -755,6 +775,10 @@ class IntegerRelation {
755775
/// equalities.
756776
bool isColZero(unsigned pos) const;
757777

778+
/// Checks for identical inequalities and eliminates redundant inequalities.
779+
/// Returns whether the constraint system was modified.
780+
bool removeDuplicateConstraints();
781+
758782
/// Returns false if the fields corresponding to various variable counts, or
759783
/// equality/inequality buffer sizes aren't consistent; true otherwise. This
760784
/// is meant to be used within an assert internally.

mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,17 @@ class PresburgerRelation {
169169
bool isIntegerEmpty() const;
170170

171171
/// Return true if there is no disjunct, false otherwise.
172-
bool isPlainEmpty() const;
172+
bool isObviouslyEmpty() const;
173173

174174
/// Return true if the set is known to have one unconstrained disjunct, false
175175
/// otherwise.
176-
bool isPlainUniverse() const;
176+
bool isObviouslyUniverse() const;
177177

178178
/// Perform a quick equality check on `this` and `other`. The relations are
179179
/// equal if the check return true, but may or may not be equal if the check
180180
/// returns false. This is doing by directly comparing whether each internal
181181
/// disjunct is the same.
182-
bool isPlainEqual(const PresburgerRelation &set) const;
182+
bool isObviouslyEqual(const PresburgerRelation &set) const;
183183

184184
/// Return true if the set is consist of a single disjunct, without any local
185185
/// variables, false otherwise.
@@ -213,6 +213,10 @@ class PresburgerRelation {
213213
/// also be a union of convex disjuncts.
214214
PresburgerRelation computeReprWithOnlyDivLocals() const;
215215

216+
/// Simplify each disjunct, canonicalizing each disjunct and removing
217+
/// redundencies.
218+
PresburgerRelation simplify() const;
219+
216220
/// Print the set's internal state.
217221
void print(raw_ostream &os) const;
218222
void dump() const;

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ bool IntegerRelation::isEqual(const IntegerRelation &other) const {
8080
return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
8181
}
8282

83-
bool IntegerRelation::isPlainEqual(const IntegerRelation &other) const {
83+
bool IntegerRelation::isObviouslyEqual(const IntegerRelation &other) const {
8484
if (!space.isEqual(other.getSpace()))
8585
return false;
8686
if (getNumEqualities() != other.getNumEqualities())
@@ -701,6 +701,12 @@ bool IntegerRelation::isEmpty() const {
701701
return false;
702702
}
703703

704+
bool IntegerRelation::isObviouslyEmpty() const {
705+
if (isEmptyByGCDTest() || hasInvalidConstraint())
706+
return true;
707+
return false;
708+
}
709+
704710
// Runs the GCD test on all equality constraints. Returns 'true' if this test
705711
// fails on any equality. Returns 'false' otherwise.
706712
// This test can be used to disprove the existence of a solution. If it returns
@@ -1079,6 +1085,57 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
10791085
return posLimit - posStart;
10801086
}
10811087

1088+
bool IntegerRelation::gaussianEliminate() {
1089+
gcdTightenInequalities();
1090+
unsigned firstVar = 0, vars = getNumVars();
1091+
unsigned nowDone, eqs, pivotRow;
1092+
for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) {
1093+
// Finds the first non-empty column.
1094+
for (; firstVar < vars; ++firstVar) {
1095+
if (!findConstraintWithNonZeroAt(firstVar, true, &pivotRow))
1096+
continue;
1097+
break;
1098+
}
1099+
// The matrix has been normalized to row echelon form.
1100+
if (firstVar >= vars)
1101+
break;
1102+
1103+
// The first pivot row found is below where it should currently be placed.
1104+
if (pivotRow > nowDone) {
1105+
equalities.swapRows(pivotRow, nowDone);
1106+
pivotRow = nowDone;
1107+
}
1108+
1109+
// Normalize all lower equations and all inequalities.
1110+
for (unsigned i = nowDone + 1; i < eqs; ++i) {
1111+
eliminateFromConstraint(this, i, pivotRow, firstVar, 0, true);
1112+
equalities.normalizeRow(i);
1113+
}
1114+
for (unsigned i = 0, ineqs = getNumInequalities(); i < ineqs; ++i) {
1115+
eliminateFromConstraint(this, i, pivotRow, firstVar, 0, false);
1116+
inequalities.normalizeRow(i);
1117+
}
1118+
gcdTightenInequalities();
1119+
}
1120+
1121+
// No redundant rows.
1122+
if (nowDone == eqs)
1123+
return false;
1124+
1125+
// Check to see if the redundant rows constant is zero, a non-zero value means
1126+
// the set is empty.
1127+
for (unsigned i = nowDone; i < eqs; ++i) {
1128+
if (atEq(i, vars) == 0)
1129+
continue;
1130+
1131+
*this = getEmpty(getSpace());
1132+
return true;
1133+
}
1134+
// Eliminate rows that are confined to be all zeros.
1135+
removeEqualityRange(nowDone, eqs);
1136+
return true;
1137+
}
1138+
10821139
// A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
10831140
// to check if a constraint is redundant.
10841141
void IntegerRelation::removeRedundantInequalities() {
@@ -1269,6 +1326,21 @@ void IntegerRelation::removeDuplicateDivs() {
12691326
divs.removeDuplicateDivs(merge);
12701327
}
12711328

1329+
void IntegerRelation::simplify() {
1330+
bool changed = true;
1331+
// Repeat until we reach a fixed point.
1332+
while (changed) {
1333+
if (isObviouslyEmpty())
1334+
return;
1335+
changed = false;
1336+
normalizeConstraintsByGCD();
1337+
changed |= gaussianEliminate();
1338+
changed |= removeDuplicateConstraints();
1339+
}
1340+
// Current set is not empty.
1341+
return;
1342+
}
1343+
12721344
/// Removes local variables using equalities. Each equality is checked if it
12731345
/// can be reduced to the form: `e = affine-expr`, where `e` is a local
12741346
/// variable and `affine-expr` is an affine expression not containing `e`.
@@ -2216,6 +2288,68 @@ IntegerPolyhedron IntegerRelation::getDomainSet() const {
22162288
return IntegerPolyhedron(std::move(copyRel));
22172289
}
22182290

2291+
bool IntegerRelation::removeDuplicateConstraints() {
2292+
bool changed = false;
2293+
SmallDenseMap<ArrayRef<MPInt>, unsigned> hashTable;
2294+
unsigned ineqs = getNumInequalities(), cols = getNumCols();
2295+
2296+
if (ineqs <= 1)
2297+
return changed;
2298+
2299+
// Check if the non-constant part of the constraint is the same.
2300+
ArrayRef<MPInt> row = getInequality(0).drop_back();
2301+
hashTable.insert({row, 0});
2302+
for (unsigned k = 1; k < ineqs; ++k) {
2303+
row = getInequality(k).drop_back();
2304+
if (!hashTable.contains(row)) {
2305+
hashTable.insert({row, k});
2306+
continue;
2307+
}
2308+
2309+
// For identical cases, keep only the smaller part of the constant term.
2310+
unsigned l = hashTable[row];
2311+
changed = true;
2312+
if (atIneq(k, cols - 1) <= atIneq(l, cols - 1))
2313+
inequalities.swapRows(k, l);
2314+
removeInequality(k);
2315+
--k;
2316+
--ineqs;
2317+
}
2318+
2319+
// Check the neg form of each inequality, need an extra vector to store it.
2320+
SmallVector<MPInt> negIneq(cols - 1);
2321+
for (unsigned k = 0; k < ineqs; ++k) {
2322+
row = getInequality(k).drop_back();
2323+
negIneq.assign(row.begin(), row.end());
2324+
for (MPInt &ele : negIneq)
2325+
ele = -ele;
2326+
if (!hashTable.contains(negIneq))
2327+
continue;
2328+
2329+
// For cases where the neg is the same as other inequalities, check that the
2330+
// sum of their constant terms is positive.
2331+
unsigned l = hashTable[row];
2332+
auto sum = atIneq(l, cols - 1) + atIneq(k, cols - 1);
2333+
if (sum > 0 || l == k)
2334+
continue;
2335+
2336+
// A sum of constant terms equal to zero combines two inequalities into one
2337+
// equation, less than zero means the set is empty.
2338+
changed = true;
2339+
if (k < l)
2340+
std::swap(l, k);
2341+
if (sum == 0) {
2342+
addEquality(getInequality(k));
2343+
removeInequality(k);
2344+
removeInequality(l);
2345+
} else
2346+
*this = getEmpty(getSpace());
2347+
break;
2348+
}
2349+
2350+
return changed;
2351+
}
2352+
22192353
IntegerPolyhedron IntegerRelation::getRangeSet() const {
22202354
IntegerRelation copyRel = *this;
22212355

mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,19 @@ void PresburgerRelation::unionInPlace(const IntegerRelation &disjunct) {
8383
void PresburgerRelation::unionInPlace(const PresburgerRelation &set) {
8484
assert(space.isCompatible(set.getSpace()) && "Spaces should match");
8585

86-
if (isPlainEqual(set))
86+
if (isObviouslyEqual(set))
8787
return;
8888

89-
if (isPlainEmpty()) {
89+
if (isObviouslyEmpty()) {
9090
disjuncts = set.disjuncts;
9191
return;
9292
}
93-
if (set.isPlainEmpty())
93+
if (set.isObviouslyEmpty())
9494
return;
9595

96-
if (isPlainUniverse())
96+
if (isObviouslyUniverse())
9797
return;
98-
if (set.isPlainUniverse()) {
98+
if (set.isObviouslyUniverse()) {
9999
disjuncts = set.disjuncts;
100100
return;
101101
}
@@ -144,10 +144,10 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
144144

145145
// If the set is empty or the other set is universe,
146146
// directly return the set
147-
if (isPlainEmpty() || set.isPlainUniverse())
147+
if (isObviouslyEmpty() || set.isObviouslyUniverse())
148148
return *this;
149149

150-
if (set.isPlainEmpty() || isPlainUniverse())
150+
if (set.isObviouslyEmpty() || isObviouslyUniverse())
151151
return set;
152152

153153
PresburgerRelation result(getSpace());
@@ -576,6 +576,9 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
576576
}
577577
}
578578

579+
// Try to simplify the results.
580+
result = result.simplify();
581+
579582
return result;
580583
}
581584

@@ -593,7 +596,7 @@ PresburgerRelation::subtract(const PresburgerRelation &set) const {
593596

594597
// If we know that the two sets are clearly equal, we can simply return the
595598
// empty set.
596-
if (isPlainEqual(set))
599+
if (isObviouslyEqual(set))
597600
return result;
598601

599602
// We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i).
@@ -615,7 +618,7 @@ bool PresburgerRelation::isEqual(const PresburgerRelation &set) const {
615618
return this->isSubsetOf(set) && set.isSubsetOf(*this);
616619
}
617620

618-
bool PresburgerRelation::isPlainEqual(const PresburgerRelation &set) const {
621+
bool PresburgerRelation::isObviouslyEqual(const PresburgerRelation &set) const {
619622
if (!space.isCompatible(set.getSpace()))
620623
return false;
621624

@@ -625,7 +628,7 @@ bool PresburgerRelation::isPlainEqual(const PresburgerRelation &set) const {
625628
// Compare each disjunct in this PresburgerRelation with the corresponding
626629
// disjunct in the other PresburgerRelation.
627630
for (unsigned int i = 0, n = getNumDisjuncts(); i < n; ++i) {
628-
if (!getDisjunct(i).isPlainEqual(set.getDisjunct(i)))
631+
if (!getDisjunct(i).isObviouslyEqual(set.getDisjunct(i)))
629632
return false;
630633
}
631634
return true;
@@ -635,18 +638,22 @@ bool PresburgerRelation::isPlainEqual(const PresburgerRelation &set) const {
635638
/// otherwise. It is a simple check that only check if the relation has at least
636639
/// one unconstrained disjunct, indicating the absence of constraints or
637640
/// conditions.
638-
bool PresburgerRelation::isPlainUniverse() const {
639-
return llvm::any_of(getAllDisjuncts(), [](const IntegerRelation &disjunct) {
640-
return disjunct.getNumConstraints() == 0;
641-
});
641+
bool PresburgerRelation::isObviouslyUniverse() const {
642+
for (const IntegerRelation &disjunct : getAllDisjuncts()) {
643+
if (disjunct.getNumConstraints() == 0)
644+
return true;
645+
}
646+
return false;
642647
}
643648

644649
bool PresburgerRelation::isConvexNoLocals() const {
645650
return getNumDisjuncts() == 1 && getSpace().getNumLocalVars() == 0;
646651
}
647652

648653
/// Return true if there is no disjunct, false otherwise.
649-
bool PresburgerRelation::isPlainEmpty() const { return getNumDisjuncts() == 0; }
654+
bool PresburgerRelation::isObviouslyEmpty() const {
655+
return getNumDisjuncts() == 0;
656+
}
650657

651658
/// Return true if all the sets in the union are known to be integer empty,
652659
/// false otherwise.
@@ -1015,6 +1022,17 @@ bool PresburgerRelation::hasOnlyDivLocals() const {
10151022
});
10161023
}
10171024

1025+
PresburgerRelation PresburgerRelation::simplify() const {
1026+
PresburgerRelation origin = *this;
1027+
PresburgerRelation result = PresburgerRelation(getSpace());
1028+
for (IntegerRelation &disjunct : origin.disjuncts) {
1029+
disjunct.simplify();
1030+
if (!disjunct.isObviouslyEmpty())
1031+
result.unionInPlace(disjunct);
1032+
}
1033+
return result;
1034+
}
1035+
10181036
void PresburgerRelation::print(raw_ostream &os) const {
10191037
os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
10201038
for (const IntegerRelation &disjunct : disjuncts) {

0 commit comments

Comments
 (0)