diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -26,6 +26,7 @@ class IntegerRelation; class IntegerPolyhedron; +class PresburgerSet; /// An IntegerRelation represents the set of points from a PresburgerSpace that /// satisfy a list of affine constraints. Affine constraints can be inequalities @@ -497,6 +498,9 @@ /// locals that have been added to `this`. unsigned mergeLocalIds(IntegerRelation &other); + /// Check whether all local ids have a division representation. + bool hasOnlyDivLocals() const; + /// Changes the partition between dimensions and symbols. Depending on the new /// symbol count, either a chunk of dimensional identifiers immediately before /// the split become symbols, or some of the symbols immediately after the @@ -739,6 +743,12 @@ /// first added identifier. unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; + /// Compute an equivalent representation of the same set, such that all local + /// ids have division representations. This representation may involve + /// local ids that correspond to divisions, and may also be a union of convex + /// disjuncts (IntegerRelations). + PresburgerSet computeReprWithOnlyDivLocals() const; + /// Compute the symbolic integer lexmin of the polyhedron. /// This finds, for every assignment to the symbols, the lexicographically /// minimum value attained by the dimensions. For example, the symbolic lexmin diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -117,6 +117,9 @@ /// disjuncts in the union. PresburgerRelation coalesce() const; + /// Check whether all local ids in all disjuncts have a div representation. + bool hasOnlyDivLocals() const; + /// Print the set's internal state. void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -572,10 +572,28 @@ /// `constraints`, and no other ids. SymbolicLexSimplex(const IntegerPolyhedron &constraints, const IntegerPolyhedron &symbolDomain) - : LexSimplexBase(constraints), domainPoly(symbolDomain), - domainSimplex(symbolDomain) { - assert(domainPoly.getNumIds() == constraints.getNumSymbolIds()); - assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds()); + : SymbolicLexSimplex(constraints, + constraints.getIdKindOffset(IdKind::Symbol), + symbolDomain) { + assert(constraints.getNumSymbolIds() == symbolDomain.getNumIds()); + } + + /// An overload to select some other subrange of ids as symbols for lexmin. + /// The symbol ids are the range of ids with absolute index + /// [symbolOffset, symbolOffset + symbolDomain.getNumIds()) + /// symbolDomain should only have dim ids. + SymbolicLexSimplex(const IntegerPolyhedron &constraints, + unsigned symbolOffset, + const IntegerPolyhedron &symbolDomain) + : LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset, + symbolDomain.getNumIds()), + domainPoly(symbolDomain), domainSimplex(symbolDomain) { + // TODO consider supporting this case. It amounts + // to just returning the input constraints. + assert(domainPoly.getNumIds() > 0 && + "there must be some non-symbols to optimize!"); + assert(domainPoly.getNumIds() == domainPoly.getNumDimIds()); + intersectIntegerRelation(constraints); } /// The lexmin will be stored as a function `lexmin` from symbols to diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -152,6 +152,51 @@ removeEqualityRange(counts.getNumEqs(), getNumEqualities()); } +PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const { + // If there are no locals, we're done. + if (getNumLocalIds() == 0) + return PresburgerSet(*this); + + // Move all the non-div locals to the end, as we need these to form a + // contiguous range. + IntegerPolyhedron copy = *this; + std::vector reprs; + copy.getLocalReprs(reprs); + unsigned offset = copy.getIdKindOffset(IdKind::Local); + unsigned numNonDivLocals = 0; + for (unsigned i = 0, e = copy.getNumLocalIds(); i < e - numNonDivLocals;) { + if (!reprs[i]) { + copy.swapId(offset + i, offset + e - numNonDivLocals - 1); + std::swap(reprs[i], reprs[e - numNonDivLocals - 1]); + ++numNonDivLocals; + continue; + } + ++i; + } + + // If there are no non-div locals, we're done. + if (numNonDivLocals == 0) + return PresburgerSet(*this); + + // We computeSymbolicIntegerLexMin by considering the non-div locals as + // "non-symbols" and considering everything else as "symbols". This will + // compute a function mapping assignments to "symbols" to the + // lexicographically minimal valid assignment of "non-symbols", when a + // satisfying assignment exists. It separately returns the set of assignments + // to the "symbols" such that a satisfying assignment to the "non-symbols" + // exists but the lexmin is unbounded. We basically want to find the set of + // values of the "symbols" such that an assignment to the "non-symbols" + // exists, which is the union of the domain of the returned lexmin function + // and the returned set of assignments to the "symbols" that makes the lexmin + // unbounded. + SymbolicLexMin result = + SymbolicLexSimplex(copy, /*symbolOffset*/ 0, + IntegerPolyhedron(PresburgerSpace::getSetSpace( + /*numDims=*/copy.getNumIds() - numNonDivLocals))) + .computeSymbolicIntegerLexMin(); + return result.lexmin.getDomain().unionSet(result.unboundedDomain); +} + SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const { // Compute the symbolic lexmin of the dims and locals, with the symbols being // the actual symbols of this set. @@ -1122,6 +1167,13 @@ return relA.getNumLocalIds() - oldALocals; } +bool IntegerRelation::hasOnlyDivLocals() const { + std::vector reprs; + getLocalReprs(reprs); + return llvm::all_of(reprs, + [](const MaybeLocalRepr &repr) { return bool(repr); }); +} + void IntegerRelation::removeDuplicateDivs() { std::vector> divs; SmallVector denoms; diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -63,7 +63,7 @@ /// A point is contained in the union iff any of the parts contain the point. bool PresburgerRelation::containsPoint(ArrayRef point) const { return llvm::any_of(disjuncts, [&](const IntegerRelation &disjunct) { - return (disjunct.containsPoint(point)); + return (disjunct.containsPointNoLocal(point)); }); } @@ -770,6 +770,12 @@ return SetCoalescer(*this).coalesce(); } +bool PresburgerRelation::hasOnlyDivLocals() const { + return llvm::all_of(disjuncts, [](const IntegerRelation &rel) { + return rel.hasOnlyDivLocals(); + }); +} + void PresburgerRelation::print(raw_ostream &os) const { os << "Number of Disjuncts: " << getNumDisjuncts() << "\n"; for (const IntegerRelation &disjunct : disjuncts) { diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -747,6 +747,44 @@ /*resultBound=*/{}); } +// The last `numToProject` dims will be projected out, i.e., converted to +// locals. +void testComputeReprAtPoints(IntegerPolyhedron poly, + ArrayRef> points, + unsigned numToProject) { + poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject, + poly.getNumDimIds(), IdKind::Local); + PresburgerSet repr = poly.computeReprWithOnlyDivLocals(); + EXPECT_TRUE(repr.hasOnlyDivLocals()); + for (const SmallVector &point : points) { + EXPECT_EQ(poly.containsPointNoLocal(point).hasValue(), + repr.containsPoint(point)); + } +} +void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected, + unsigned numToProject) { + poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject, + poly.getNumDimIds(), IdKind::Local); + PresburgerSet repr = poly.computeReprWithOnlyDivLocals(); + EXPECT_TRUE(repr.hasOnlyDivLocals()); + EXPECT_TRUE(repr.isEqual(expected)); +} + +TEST(SetTest, computeReprWithOnlyDivLocals) { + testComputeReprAtPoints(parsePoly("(x, y) : (x - 2*y == 0)"), + {{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}}, + /*numToProject=*/0); + testComputeReprAtPoints(parsePoly("(x, e) : (x - 2*e == 0)"), + {{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1); + + // Bezout's lemma: if a, b are constants, + // the set of values that ax + by can take is all multiples of gcd(a, b). + testComputeRepr( + parsePoly("(x, e, f) : (x - 15*e - 21*f == 0)"), + PresburgerSet(parsePoly({"(x) : (x - 3*(x floordiv 3) == 0)"})), + /*numToProject=*/2); +} + TEST(SetTest, subtractOutputSizeRegression) { PresburgerSet set1 = parsePresburgerSetFromPolyStrings(1, {"(i) : (i >= 0, 10 - i >= 0)"});