diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -27,6 +27,7 @@ class IntegerRelation; class IntegerPolyhedron; class PresburgerSet; +class PresburgerRelation; /// An IntegerRelation represents the set of points from a PresburgerSpace that /// satisfy a list of affine constraints. Affine constraints can be inequalities @@ -575,6 +576,12 @@ /// this for uniformity with `applyDomain`. void applyRange(const IntegerRelation &rel); + /// Compute an equivalent representation of the same set, such that all local + /// vars in all disjuncts have division representations. This representation + /// may involve local vars that correspond to divisions, and may also be a + /// union of convex disjuncts. + PresburgerRelation computeReprWithOnlyDivLocals() const; + void print(raw_ostream &os) const; void dump() const; @@ -760,12 +767,6 @@ /// first added variable. unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override; - /// Compute an equivalent representation of the same set, such that all local - /// ids have division representations. This representation may involve - /// local ids that correspond to divisions, and may also be a union of convex - /// disjuncts. - PresburgerSet computeReprWithOnlyDivLocals() const; - /// Compute the symbolic integer lexmin of the polyhedron. /// This finds, for every assignment to the symbols, the lexicographically /// minimum value attained by the dimensions. 
For example, the symbolic lexmin diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -128,6 +128,12 @@ /// Check whether all local ids in all disjuncts have a div representation. bool hasOnlyDivLocals() const; + /// Compute an equivalent representation of the same relation, such that all + /// local ids in all disjuncts have division representations. This + /// representation may involve local ids that correspond to divisions, and may + /// also be a union of convex disjuncts. + PresburgerRelation computeReprWithOnlyDivLocals() const; + /// Print the set's internal state. void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -570,7 +570,7 @@ /// `symbolDomain` is the set of values of the symbols for which the lexmin /// will be computed. `symbolDomain` should have a dim var for every symbol in /// `constraints`, and no other vars. - SymbolicLexSimplex(const IntegerPolyhedron &constraints, + SymbolicLexSimplex(const IntegerRelation &constraints, const IntegerPolyhedron &symbolDomain) : SymbolicLexSimplex(constraints, constraints.getVarKindOffset(VarKind::Symbol), @@ -582,8 +582,7 @@ /// The symbol ids are the range of ids with absolute index /// [symbolOffset, symbolOffset + symbolDomain.getNumVars()) /// symbolDomain should only have dim ids. 
- SymbolicLexSimplex(const IntegerPolyhedron &constraints, - unsigned symbolOffset, + SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset, const IntegerPolyhedron &symbolDomain) : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset, symbolDomain.getNumVars()), diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -165,16 +165,16 @@ removeEqualityRange(counts.getNumEqs(), getNumEqualities()); } -PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const { +PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const { // If there are no locals, we're done. if (getNumLocalVars() == 0) - return PresburgerSet(*this); + return PresburgerRelation(*this); // Move all the non-div locals to the end, as the current API to // SymbolicLexMin requires these to form a contiguous range. // // Take a copy so we can perform mutations. - IntegerPolyhedron copy = *this; + IntegerRelation copy = *this; std::vector reprs; copy.getLocalReprs(reprs); @@ -197,7 +197,7 @@ // If there are no non-div locals, we're done. if (numNonDivLocals == 0) - return PresburgerSet(*this); + return PresburgerRelation(*this); // We computeSymbolicIntegerLexMin by considering the non-div locals as // "non-symbols" and considering everything else as "symbols". This will diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -136,6 +136,17 @@ return getNegatedCoeffs(eqCoeffs); } +PresburgerRelation PresburgerRelation::computeReprWithOnlyDivLocals() const { + if (hasOnlyDivLocals()) + return *this; + + // The result is just the union of the reprs of the disjuncts. 
+ PresburgerRelation result(getSpace()); + for (const IntegerRelation &disjunct : disjuncts) + result.unionInPlace(disjunct.computeReprWithOnlyDivLocals()); + return result; +} + /// Return the set difference b \ s. /// /// In the following, U denotes union, /\ denotes intersection, \ denotes set @@ -174,6 +185,9 @@ if (b.isEmptyByGCDTest()) return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals()); + if (!s.hasOnlyDivLocals()) + return getSetDifference(b, s.computeReprWithOnlyDivLocals()); + // Remove duplicate divs up front here to avoid existing // divs disappearing in the call to mergeLocalVars below. b.removeDuplicateDivs(); diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -431,6 +431,10 @@ EXPECT_TRUE(s.isEqual(t)); } +void expectEqual(const IntegerPolyhedron &s, const IntegerPolyhedron &t) { + EXPECT_TRUE(s.isEqual(t)); +} + void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); } TEST(SetTest, divisions) { @@ -505,6 +509,45 @@ expectEqual(evens, PresburgerSet(evensDefByIneq)); } +TEST(SetTest, divisionNonDivLocals) { + // This is a tetrahedron with vertices at + // (1/3, 0, 0), (2/3, 0, 0), (2/3, 0, 1000), and (1000, 1000, 1000). + // + // The only integer point in this is at (1000, 1000, 1000). + // We project this to the xy plane. + IntegerPolyhedron tetrahedron = + parsePolyAndMakeLocals("(x, y, z) : (y >= 0, z - y >= 0, 3000*x - 2998*y " + "- 1000 - z >= 0, -1500*x + 1499*y + 1000 >= 0)", + /*numLocals=*/1); + + // This is a triangle with vertices at (1/3, 0), (2/3, 0) and (1000, 1000). + // The only integer point in this is at (1000, 1000). + // + // It also happens to be the projection of the above onto the xy plane. 
+ IntegerPolyhedron triangle = parsePoly("(x,y) : (y >= 0, " + "3000 * x - 2999 * y - 1000 >= 0, " + "-3000 * x + 2998 * y + 2000 >= 0)"); + EXPECT_TRUE(triangle.containsPoint({1000, 1000})); + EXPECT_FALSE(triangle.containsPoint({1001, 1001})); + expectEqual(triangle, tetrahedron); + + convertSuffixDimsToLocals(triangle, 1); + IntegerPolyhedron line = parsePoly("(x) : (x - 1000 == 0)"); + expectEqual(line, triangle); + + // Triangle with vertices (0, 0), (5, 0), (15, 5). + // Projected on x, it becomes [0, 13] U {15} as it becomes too narrow towards + // the apex and so does not have any integer point at x = 14. + // At x = 15, the apex is an integer point. + PresburgerSet triangle2{parsePolyAndMakeLocals("(x,y) : (y >= 0, " + "x - 3*y >= 0, " + "2*y - x + 5 >= 0)", + /*numLocals=*/1)}; + PresburgerSet zeroToThirteen{parsePoly("(x) : (13 - x >= 0, x >= 0)")}; + PresburgerSet fifteen{parsePoly("(x) : (x - 15 == 0)")}; + expectEqual(triangle2.subtract(zeroToThirteen), fifteen); +} + TEST(SetTest, subtractDuplicateDivsRegression) { // Previously, subtracting sets with duplicate divs might result in crashes // due to existing divs being removed when merging local ids, due to being @@ -797,7 +840,7 @@ unsigned numToProject) { poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject, poly.getNumDimVars(), VarKind::Local); - PresburgerSet repr = poly.computeReprWithOnlyDivLocals(); + PresburgerRelation repr = poly.computeReprWithOnlyDivLocals(); EXPECT_TRUE(repr.hasOnlyDivLocals()); EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace())); for (const SmallVector &point : points) { @@ -810,7 +853,7 @@ unsigned numToProject) { poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject, poly.getNumDimVars(), VarKind::Local); - PresburgerSet repr = poly.computeReprWithOnlyDivLocals(); + PresburgerRelation repr = poly.computeReprWithOnlyDivLocals(); EXPECT_TRUE(repr.hasOnlyDivLocals()); 
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace())); EXPECT_TRUE(repr.isEqual(expected));