diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -755,9 +755,9 @@ unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; /// Compute an equivalent representation of the same set, such that all local - /// ids have division representations. This representation may involve - /// local ids that correspond to divisions, and may also be a union of convex - /// disjuncts. + /// ids in all disjuncts have division representations. This representation + /// may involve local ids that correspond to divisions, and may also be a + /// union of convex disjuncts. PresburgerSet computeReprWithOnlyDivLocals() const; /// Compute the symbolic integer lexmin of the polyhedron. diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -128,6 +128,13 @@ /// Check whether all local ids in all disjuncts have a div representation. bool hasOnlyDivLocals() const; + /// Compute an equivalent representation of the same set, such that all local + /// ids in all disjuncts have division representations. This representation + /// may involve local ids that correspond to divisions, and may also be a + /// union of convex disjuncts. Each disjunct is processed as an + /// IntegerPolyhedron, so this is currently only supported for sets. + PresburgerRelation computeReprWithOnlyDivLocals() const; + /// Print the set's internal state.
void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -570,7 +570,7 @@ /// `symbolDomain` is the set of values of the symbols for which the lexmin /// will be computed. `symbolDomain` should have a dim id for every symbol in /// `constraints`, and no other ids. - SymbolicLexSimplex(const IntegerPolyhedron &constraints, + SymbolicLexSimplex(const IntegerRelation &constraints, const IntegerPolyhedron &symbolDomain) : SymbolicLexSimplex(constraints, constraints.getIdKindOffset(IdKind::Symbol), @@ -582,8 +582,7 @@ /// The symbol ids are the range of ids with absolute index /// [symbolOffset, symbolOffset + symbolDomain.getNumIds()) /// symbolDomain should only have dim ids. - SymbolicLexSimplex(const IntegerPolyhedron &constraints, - unsigned symbolOffset, + SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset, const IntegerPolyhedron &symbolDomain) : LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset, symbolDomain.getNumIds()), diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -136,6 +136,23 @@ return getNegatedCoeffs(eqCoeffs); } +PresburgerRelation PresburgerRelation::computeReprWithOnlyDivLocals() const { + if (hasOnlyDivLocals()) + return *this; + + // Each disjunct is handled as an IntegerPolyhedron (i.e. a set) below, so + // relations with domain ids are not supported yet. + assert(getSpace().getNumDomainIds() == 0 && + "computeReprWithOnlyDivLocals is only supported for sets"); + + // The result is just the union of the reprs of the disjuncts. + PresburgerRelation result(getSpace()); + for (const IntegerRelation &disjunct : disjuncts) + result.unionInPlace( + IntegerPolyhedron(disjunct).computeReprWithOnlyDivLocals()); + return result; +} + /// Return the set difference b \ s.
/// /// In the following, U denotes union, /\ denotes intersection, \ denotes set @@ -174,6 +186,11 @@ if (b.isEmptyByGCDTest()) return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals()); + // Subtraction below requires every local of s to have a division + // representation, so first rewrite s into an equivalent such form. + if (!s.hasOnlyDivLocals()) + return getSetDifference(b, s.computeReprWithOnlyDivLocals()); + // Remove duplicate divs up front here to avoid existing // divs disappearing in the call to mergeLocalIds below. b.removeDuplicateDivs(); diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -431,6 +431,10 @@ EXPECT_TRUE(s.isEqual(t)); } +void expectEqual(const IntegerPolyhedron &s, const IntegerPolyhedron &t) { + EXPECT_TRUE(s.isEqual(t)); +} + void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); } TEST(SetTest, divisions) { @@ -505,6 +509,45 @@ expectEqual(evens, PresburgerSet(evensDefByIneq)); } +TEST(SetTest, divisionNonDivLocals) { + // This is a tetrahedron with vertices at + // (1/3, 0, 0), (2/3, 0, 0), (2/3, 0, 1000), and (1000, 1000, 1000). + // + // The only integer point in this is at (1000, 1000, 1000). + // We project this to the xy plane by turning z into a local. + IntegerPolyhedron tetrahedron = + parsePolyAndMakeLocals("(x, y, z) : (y >= 0, z - y >= 0, 3000*x - 2998*y " + "- 1000 - z >= 0, -1500*x + 1499*y + 1000 >= 0)", + /*numLocals=*/1); + + // This is a triangle with vertices at (1/3, 0), (2/3, 0) and (1000, 1000). + // The only integer point in this is at (1000, 1000). + // + // It also happens to be the projection of the above onto the xy plane.
+ IntegerPolyhedron triangle = parsePoly("(x,y) : (y >= 0, " + "3000 * x - 2999 * y - 1000 >= 0, " + "-3000 * x + 2998 * y + 2000 >= 0)"); + EXPECT_TRUE(triangle.containsPoint({1000, 1000})); + EXPECT_FALSE(triangle.containsPoint({1001, 1001})); + expectEqual(triangle, tetrahedron); + + convertSuffixDimsToLocals(triangle, 1); + IntegerPolyhedron line = parsePoly("(x) : (x - 1000 == 0)"); + expectEqual(line, triangle); + + // Triangle with vertices (0, 0), (5, 0), (15, 5). + // Projected on x, it becomes [0, 13] U {15} as it becomes too narrow towards + // the apex and so does not have any integer point at x = 14. + // At x = 15, the apex is an integer point. + PresburgerSet triangle2{parsePolyAndMakeLocals("(x,y) : (y >= 0, " + "x - 3*y >= 0, " + "2*y - x + 5 >= 0)", + /*numLocals=*/1)}; + PresburgerSet zeroToThirteen{parsePoly("(x) : (13 - x >= 0, x >= 0)")}; + PresburgerSet fifteen{parsePoly("(x) : (x - 15 == 0)")}; + expectEqual(triangle2.subtract(zeroToThirteen), fifteen); +} + TEST(SetTest, subtractDuplicateDivsRegression) { // Previously, subtracting sets with duplicate divs might result in crashes // due to existing divs being removed when merging local ids, due to being