diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -190,7 +190,25 @@
     return;
   }
   FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
-  unsigned bInitNumLocals = b.getNumLocalIds();
+
+  // Below, we append some additional constraints and ids to b. We want to
+  // rollback b to its initial state before returning, which we will do by
+  // removing all constraints beyond the original number of inequalities
+  // and equalities, so we store these counts first.
+  const unsigned bInitNumIneqs = b.getNumInequalities();
+  const unsigned bInitNumEqs = b.getNumEqualities();
+  const unsigned bInitNumLocals = b.getNumLocalIds();
+  // Similarly, we also want to rollback simplex to its original state.
+  const unsigned initialSnapshot = simplex.getSnapshot();
+
+  // Restore b and simplex to the saved state. Call this before every return.
+  auto restoreState = [&]() {
+    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                    b.getNumLocalIds());
+    b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+    b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
+    simplex.rollback(initialSnapshot);
+  };
 
   // Find out which inequalities of sI correspond to division inequalities for
   // the local variables of sI.
@@ -219,7 +237,6 @@
     isDivInequality[maybePair->second] = true;
   }
 
-  unsigned initialSnapshot = simplex.getSnapshot();
   unsigned offset = simplex.getNumConstraints();
   unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
   simplex.appendVariable(numLocalsAdded);
@@ -229,9 +246,9 @@
 
   if (simplex.isEmpty()) {
     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
-    simplex.rollback(initialSnapshot);
-    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
-                    b.getNumLocalIds());
+    /// We are ignoring level i completely, so we restore the state
+    /// *before* going to level i + 1.
+    restoreState();
     subtractRecursively(b, simplex, s, i + 1, result);
     return;
   }
@@ -270,13 +287,6 @@
     simplex.addInequality(ineq);
   };
 
-  // processInequality appends some additional constraints to b. We want to
-  // rollback b to its initial state before returning, which we will do by
-  // removing all constraints beyond the original number of inequalities
-  // and equalities, so we store these counts first.
-  unsigned bInitNumIneqs = b.getNumInequalities();
-  unsigned bInitNumEqs = b.getNumEqualities();
-
   // Process all the inequalities, ignoring redundant inequalities and division
   // inequalities. The result is correct whether or not we ignore these, but
   // ignoring them makes the result simpler.
@@ -302,13 +312,7 @@
     processInequality(getNegatedCoeffs(coeffs));
   }
 
-  // Rollback b and simplex to their initial states.
-  b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
-                  b.getNumLocalIds());
-  b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
-  b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
-
-  simplex.rollback(initialSnapshot);
+  restoreState();
 }
 
 /// Return the set difference fac \ set.
diff --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -22,6 +22,17 @@
 
 namespace mlir {
 
+/// Parses a FlatAffineConstraints from a StringRef. It is expected that the
+/// string represents a valid IntegerSet, otherwise it will violate a gtest
+/// assertion.
+static FlatAffineConstraints parseFAC(StringRef str, MLIRContext *context) {
+  FailureOr<FlatAffineConstraints> fac = parseIntegerSetToFAC(str, context);
+
+  EXPECT_TRUE(succeeded(fac));
+
+  return *fac;
+}
+
 /// Compute the union of s and t, and check that each of the given points
 /// belongs to the union iff it belongs to at least one of s and t.
 static void testUnionAtPoints(PresburgerSet s, PresburgerSet t,
@@ -620,6 +631,7 @@
 void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
 
 TEST(SetTest, divisions) {
+  MLIRContext context;
   // Note: we currently need to add the equalities as inequalities to the FAC
   // since detecting divisions based on equalities is not yet supported.
 
@@ -644,17 +656,12 @@
   expectEqual(odds.complement(), evens);
   // even multiples of 3 = multiples of 6.
   expectEqual(multiples3.intersect(evens), multiples6);
-}
-
-/// Parses a FlatAffineConstraints from a StringRef. It is expected that the
-/// string represents a valid IntegerSet, otherwise it will violate a gtest
-/// assertion.
-static FlatAffineConstraints parseFAC(StringRef str, MLIRContext *context) {
-  FailureOr<FlatAffineConstraints> fac = parseIntegerSetToFAC(str, context);
-
-  EXPECT_TRUE(succeeded(fac));
 
-  return *fac;
+  PresburgerSet setA =
+      makeSetFromFACs(1, {parseFAC("(x) : (-x >= 0)", &context)});
+  PresburgerSet setB =
+      makeSetFromFACs(1, {parseFAC("(x) : (x floordiv 2 - 4 >= 0)", &context)});
+  EXPECT_TRUE(setA.subtract(setB).isEqual(setA));
 }
 
 /// Coalesce `set` and check that the `newSet` is equal to `set and that