diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -664,6 +664,27 @@
   void reduceBasis(Matrix &basis, unsigned level);
 };
 
+/// Takes a snapshot of the simplex state on construction and rolls back to the
+/// snapshot on destruction.
+///
+/// Useful for performing operations in a "transient context": all changes made
+/// to the simplex while the guard is alive are rolled back on scope exit.
+class SimplexRollbackScopeExit {
+public:
+  explicit SimplexRollbackScopeExit(SimplexBase &simplex)
+      : simplex(simplex), snapshot(simplex.getSnapshot()) {}
+  ~SimplexRollbackScopeExit() { simplex.rollback(snapshot); }
+
+  // Non-copyable: every copy would perform its own rollback on destruction.
+  SimplexRollbackScopeExit(const SimplexRollbackScopeExit &) = delete;
+  SimplexRollbackScopeExit &
+  operator=(const SimplexRollbackScopeExit &) = delete;
+
+private:
+  SimplexBase &simplex;
+  unsigned snapshot;
+};
+
 } // namespace presburger
 } // namespace mlir
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -232,12 +232,11 @@
   // inequality, s_{i,j+1}. This function recurses into the next level i + 1
   // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
   auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
-    size_t snapshot = simplex.getSnapshot();
+    SimplexRollbackScopeExit scopeExit(simplex);
     b.addInequality(ineq);
     simplex.addInequality(ineq);
     subtractRecursively(b, simplex, s, i + 1, result);
     b.removeInequality(b.getNumInequalities() - 1);
-    simplex.rollback(snapshot);
   };
 
   // For each inequality ineq, we first recurse with the part where ineq
@@ -519,16 +518,11 @@
 /// that all inequalities of `cuttingIneqsB` are redundant for the facet of
 /// `simp` where `ineq` holds as an equality is contained within `a`.
 bool SetCoalescer::isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp) {
-  unsigned snapshot = simp.getSnapshot();
+  SimplexRollbackScopeExit scopeExit(simp);
   simp.addEquality(ineq);
-  if (llvm::any_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
-        return !simp.isRedundantInequality(curr);
-      })) {
-    simp.rollback(snapshot);
-    return false;
-  }
-  simp.rollback(snapshot);
-  return true;
+  return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
+    return simp.isRedundantInequality(curr);
+  });
 }
 
 void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -888,11 +888,11 @@
                                                ArrayRef<int64_t> coeffs) {
   if (empty)
     return OptimumKind::Empty;
-  unsigned snapshot = getSnapshot();
+
+  SimplexRollbackScopeExit scopeExit(*this);
   unsigned conIndex = addRow(coeffs);
   unsigned row = con[conIndex].pos;
   MaybeOptimum<Fraction> optimum = computeRowOptimum(direction, row);
-  rollback(snapshot);
   return optimum;
 }
 
@@ -1205,7 +1205,7 @@
     // tableau before returning. We instead add a row for the objective function
     // ourselves, call into computeOptimum, compute the duals from the tableau
     // state, and finally rollback the addition of the row before returning.
-    unsigned snap = simplex.getSnapshot();
+    SimplexRollbackScopeExit scopeExit(simplex);
     unsigned conIndex = simplex.addRow(getCoeffsForDirection(dir));
     unsigned row = simplex.con[conIndex].pos;
     MaybeOptimum<Fraction> maybeWidth =
@@ -1248,7 +1248,6 @@
       else
         dual.push_back(0);
     }
-    simplex.rollback(snap);
     return *maybeWidth;
   }
 