diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -144,6 +144,30 @@ return inequalities.getRow(idx); } + /// The struct CountsSnapshot stores the count of each IdKind, and also of + /// each constraint type. getCounts() returns a CountsSnapshot object + /// describing the current state of the IntegerRelation. truncate() truncates + /// all ids of each IdKind and all constraints of both kinds beyond the counts + /// in the specified CountsSnapshot object. This can be used to achieve + /// rudimentary rollback support. As long as none of the existing constraints + /// or ids are disturbed, and only additional ids or constraints are added, + /// this addition can be rolled back using truncate. + struct CountsSnapshot { + public: + CountsSnapshot(const PresburgerLocalSpace &space, unsigned numIneqs, + unsigned numEqs) + : space(space), numIneqs(numIneqs), numEqs(numEqs) {} + const PresburgerLocalSpace &getSpace() const { return space; } + unsigned getNumIneqs() const { return numIneqs; } + unsigned getNumEqs() const { return numEqs; } + + private: + PresburgerLocalSpace space; + unsigned numIneqs, numEqs; + }; + CountsSnapshot getCounts() const; + void truncate(const CountsSnapshot &counts); + /// Insert `num` identifiers of the specified kind at position `pos`. /// Positions are relative to the kind of identifier. The coefficient columns /// corresponding to the added identifiers are initialized to zero. Return the @@ -482,6 +506,11 @@ /// arrays as needed. void removeIdRange(unsigned idStart, unsigned idLimit); + using PresburgerSpace::truncateIdKind; + /// Truncate the ids to the number in the space of the specified + /// CountsSnapshot.
+ void truncateIdKind(IdKind kind, const CountsSnapshot &counts); + /// A parameter that controls detection of an unrealistic number of /// constraints. If the number of constraints is this many times the number of /// variables, we consider such a system out of line with the intended use diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -109,6 +109,10 @@ /// idLimit). The range is relative to the kind of identifier. virtual void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit); + /// Truncate the ids of the specified kind to the specified number by dropping + /// some ids at the end. `num` must not exceed the current number. + void truncateIdKind(IdKind kind, unsigned num); + /// Returns true if both the spaces are equal i.e. if both spaces have the /// same number of identifiers of each kind (excluding Local Identifiers).
bool isEqual(const PresburgerSpace &other) const; diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -119,6 +119,26 @@ if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count))) poly.removeInequality(i - 1); } + +IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const { + return {PresburgerLocalSpace(*this), getNumInequalities(), + getNumEqualities()}; +} + +void IntegerRelation::truncateIdKind(IdKind kind, + const CountsSnapshot &counts) { + truncateIdKind(kind, counts.getSpace().getNumIdKind(kind)); +} + +void IntegerRelation::truncate(const CountsSnapshot &counts) { + truncateIdKind(IdKind::Domain, counts); + truncateIdKind(IdKind::Range, counts); + truncateIdKind(IdKind::Symbol, counts); + truncateIdKind(IdKind::Local, counts); + removeInequalityRange(counts.getNumIneqs(), getNumInequalities()); + removeEqualityRange(counts.getNumEqs(), getNumEqualities()); +} + unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -155,16 +155,12 @@ // rollback b to its initial state before returning, which we will do by // removing all constraints beyond the original number of inequalities // and equalities, so we store these counts first. - const unsigned bInitNumIneqs = b.getNumInequalities(); - const unsigned bInitNumEqs = b.getNumEqualities(); - const unsigned bInitNumLocals = b.getNumLocalIds(); + const IntegerRelation::CountsSnapshot bCounts = b.getCounts(); // Similarly, we also want to rollback simplex to its original state.
const unsigned initialSnapshot = simplex.getSnapshot(); auto restoreState = [&]() { - b.removeIdRange(IdKind::Local, bInitNumLocals, b.getNumLocalIds()); - b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities()); - b.removeEqualityRange(bInitNumEqs, b.getNumEqualities()); + b.truncate(bCounts); simplex.rollback(initialSnapshot); }; @@ -200,7 +196,8 @@ } unsigned offset = simplex.getNumConstraints(); - unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals; + unsigned numLocalsAdded = + b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds(); simplex.appendVariable(numLocalsAdded); unsigned snapshotBeforeIntersect = simplex.getSnapshot(); diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -91,6 +91,12 @@ llvm_unreachable("PresburgerSpace does not support local identifiers!"); } +void PresburgerSpace::truncateIdKind(IdKind kind, unsigned num) { + unsigned curNum = getNumIdKind(kind); + assert(num <= curNum && "Can't truncate to more ids!"); + removeIdRange(kind, num, curNum); +} + unsigned PresburgerLocalSpace::insertId(IdKind kind, unsigned pos, unsigned num) { if (kind == IdKind::Local) {