diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -144,6 +144,30 @@ return inequalities.getRow(idx); } + /// The struct Counts stores the count of each IdKind, and also of each + /// constraint type. getCounts() returns a Counts object describing the + /// current state of the IntegerRelation. truncate() truncates all ids of each + /// IdKind and all constraints of both kinds beyond the counts in the + /// specified Counts object. This can be used to achieve rudimentary + /// rollback support. As long as none of the existing constraints or ids are + /// disturbed, and only additional ids or constraints are added, this addition + /// can be rolled back using truncate. + struct Counts { + public: + Counts(const PresburgerLocalSpace &space, unsigned numIneqs, + unsigned numEqs) + : space(space), numIneqs(numIneqs), numEqs(numEqs) {} + const PresburgerLocalSpace &getSpace() const { return space; } + unsigned getNumIneqs() const { return numIneqs; } + unsigned getNumEqs() const { return numEqs; } + + private: + PresburgerLocalSpace space; + unsigned numIneqs, numEqs; + }; + Counts getCounts() const; + void truncate(const Counts &counts); + /// Insert `num` identifiers of the specified kind at position `pos`. /// Positions are relative to the kind of identifier. The coefficient columns /// corresponding to the added identifiers are initialized to zero. Return the @@ -482,6 +506,12 @@ /// arrays as needed. void removeIdRange(unsigned idStart, unsigned idLimit); + /// Truncate the ids of the specified kind to the specified count by dropping + /// some ids at the end. + void truncateIdKind(IdKind kind, unsigned num); + /// Truncate the ids of `kind` to that kind's count in the space of `counts`.
+ void truncateIdKind(IdKind kind, const Counts &counts); + /// A parameter that controls detection of an unrealistic number of /// constraints. If the number of constraints is this many times the number of /// variables, we consider such a system out of line with the intended use diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -119,6 +119,31 @@ if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count))) poly.removeInequality(i - 1); } + +IntegerRelation::Counts IntegerRelation::getCounts() const { + return {PresburgerLocalSpace(*this), getNumInequalities(), + getNumEqualities()}; +} + +void IntegerRelation::truncateIdKind(IdKind kind, unsigned num) { + unsigned curNum = getNumIdKind(kind); + assert(num <= curNum && "Can't truncate to more ids!"); + removeIdRange(kind, num, curNum); +} + +void IntegerRelation::truncateIdKind(IdKind kind, const Counts &counts) { + truncateIdKind(kind, counts.getSpace().getNumIdKind(kind)); +} + +void IntegerRelation::truncate(const Counts &counts) { + truncateIdKind(IdKind::Domain, counts); + truncateIdKind(IdKind::Range, counts); + truncateIdKind(IdKind::Symbol, counts); + truncateIdKind(IdKind::Local, counts); + removeInequalityRange(counts.getNumIneqs(), getNumInequalities()); + removeEqualityRange(counts.getNumEqs(), getNumEqualities()); +} + unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -155,16 +155,12 @@ // rollback b to its initial state before returning, which we will do by // removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first. - const unsigned bInitNumIneqs = b.getNumInequalities(); - const unsigned bInitNumEqs = b.getNumEqualities(); - const unsigned bInitNumLocals = b.getNumLocalIds(); + const IntegerRelation::Counts bCounts = b.getCounts(); // Similarly, we also want to rollback simplex to its original state. const unsigned initialSnapshot = simplex.getSnapshot(); auto restoreState = [&]() { - b.removeIdRange(IdKind::Local, bInitNumLocals, b.getNumLocalIds()); - b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities()); - b.removeEqualityRange(bInitNumEqs, b.getNumEqualities()); + b.truncate(bCounts); // Roll b back to the counts recorded in bCounts. simplex.rollback(initialSnapshot); }; @@ -200,7 +196,8 @@ } unsigned offset = simplex.getNumConstraints(); - unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals; + unsigned numLocalsAdded = + b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds(); simplex.appendVariable(numLocalsAdded); unsigned snapshotBeforeIntersect = simplex.getSnapshot();