Skip to content

Commit 08543a5

Browse files
committed
[MLIR][Presburger] Introduce SimplexRollbackScopeExit to rollback on scope exit
This simplifies many places where we just want to do something in a "transient context" and return some value. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D122172
1 parent 93b9f50 commit 08543a5

File tree

3 files changed

+25
-15
lines changed

3 files changed

+25
-15
lines changed

mlir/include/mlir/Analysis/Presburger/Simplex.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,23 @@ class Simplex : public SimplexBase {
664664
void reduceBasis(Matrix &basis, unsigned level);
665665
};
666666

667+
/// Takes a snapshot of the simplex state on construction and rolls back to the
668+
/// snapshot on destruction.
669+
///
670+
/// Useful for performing operations in a "transient context", all changes from
671+
/// which get rolled back on scope exit.
672+
class SimplexRollbackScopeExit {
673+
public:
674+
SimplexRollbackScopeExit(Simplex &simplex) : simplex(simplex) {
675+
snapshot = simplex.getSnapshot();
676+
};
677+
~SimplexRollbackScopeExit() { simplex.rollback(snapshot); }
678+
679+
private:
680+
SimplexBase &simplex;
681+
unsigned snapshot;
682+
};
683+
667684
} // namespace presburger
668685
} // namespace mlir
669686

mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,11 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
232232
// inequality, s_{i,j+1}. This function recurses into the next level i + 1
233233
// with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
234234
auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
235-
size_t snapshot = simplex.getSnapshot();
235+
SimplexRollbackScopeExit scopeExit(simplex);
236236
b.addInequality(ineq);
237237
simplex.addInequality(ineq);
238238
subtractRecursively(b, simplex, s, i + 1, result);
239239
b.removeInequality(b.getNumInequalities() - 1);
240-
simplex.rollback(snapshot);
241240
};
242241

243242
// For each inequality ineq, we first recurse with the part where ineq
@@ -519,16 +518,11 @@ PresburgerRelation SetCoalescer::coalesce() {
519518
/// that all inequalities of `cuttingIneqsB` are redundant for the facet of
520519
/// `simp` where `ineq` holds as an equality is contained within `a`.
521520
bool SetCoalescer::isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp) {
522-
unsigned snapshot = simp.getSnapshot();
521+
SimplexRollbackScopeExit scopeExit(simp);
523522
simp.addEquality(ineq);
524-
if (llvm::any_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
525-
return !simp.isRedundantInequality(curr);
526-
})) {
527-
simp.rollback(snapshot);
528-
return false;
529-
}
530-
simp.rollback(snapshot);
531-
return true;
523+
return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
524+
return simp.isRedundantInequality(curr);
525+
});
532526
}
533527

534528
void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,

mlir/lib/Analysis/Presburger/Simplex.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -888,11 +888,11 @@ MaybeOptimum<Fraction> Simplex::computeOptimum(Direction direction,
888888
ArrayRef<int64_t> coeffs) {
889889
if (empty)
890890
return OptimumKind::Empty;
891-
unsigned snapshot = getSnapshot();
891+
892+
SimplexRollbackScopeExit scopeExit(*this);
892893
unsigned conIndex = addRow(coeffs);
893894
unsigned row = con[conIndex].pos;
894895
MaybeOptimum<Fraction> optimum = computeRowOptimum(direction, row);
895-
rollback(snapshot);
896896
return optimum;
897897
}
898898

@@ -1205,7 +1205,7 @@ class presburger::GBRSimplex {
12051205
// tableau before returning. We instead add a row for the objective function
12061206
// ourselves, call into computeOptimum, compute the duals from the tableau
12071207
// state, and finally rollback the addition of the row before returning.
1208-
unsigned snap = simplex.getSnapshot();
1208+
SimplexRollbackScopeExit scopeExit(simplex);
12091209
unsigned conIndex = simplex.addRow(getCoeffsForDirection(dir));
12101210
unsigned row = simplex.con[conIndex].pos;
12111211
MaybeOptimum<Fraction> maybeWidth =
@@ -1248,7 +1248,6 @@ class presburger::GBRSimplex {
12481248
else
12491249
dual.push_back(0);
12501250
}
1251-
simplex.rollback(snap);
12521251
return *maybeWidth;
12531252
}
12541253

0 commit comments

Comments
 (0)