diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -29,7 +29,7 @@ class IntegerPolyhedron; class PresburgerSet; class PresburgerRelation; -struct SymbolicLexMin; +struct SymbolicLexOpt; /// The type of bound: equal, lower bound or upper bound. enum class BoundType { EQ, LB, UB }; @@ -659,15 +659,39 @@ /// x = a if b <= a, a <= c /// x = b if a < b, b <= c /// - /// This function is stored in the `lexmin` function in the result. + /// This function is stored in the `lexopt` function in the result. /// Some assignments to the symbols might make the set empty. /// Such points are not part of the function's domain. /// In the above example, this happens when max(a, b) > c. /// /// For some values of the symbols, the lexmin may be unbounded. - /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate + /// `SymbolicLexOpt` stores these parts of the symbolic domain in a separate /// `PresburgerSet`, `unboundedDomain`. - SymbolicLexMin findSymbolicIntegerLexMin() const; + SymbolicLexOpt findSymbolicIntegerLexMin() const; + + /// Compute the symbolic integer lexmax of the relation. + /// + /// This finds, for every assignment to the symbols and domain, + /// the lexicographically maximum value attained by the range. + /// + /// For example, the symbolic lexmax of the set + /// + /// (x)[a, b, c] : (a <= x, x <= b, x <= c) + /// + /// can be written as + /// + /// x = b if b <= c, a <= b + /// x = c if c < b, a <= c + /// + /// This function is stored in the `lexopt` function in the result. + /// Some assignments to the symbols might make the set empty. + /// Such points are not part of the function's domain. + /// In the above example, this happens when min(b, c) < a. + /// + /// For some values of the symbols, the lexmax may be unbounded.
+ /// `SymbolicLexOpt` stores these parts of the symbolic domain in a separate + /// `PresburgerSet`, `unboundedDomain`. + SymbolicLexOpt findSymbolicIntegerLexMax() const; /// Return the set difference of this set and the given set, i.e., /// return `this \ set`. diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -43,7 +43,7 @@ /// these constraints that are redundant, i.e. a subset of constraints that /// doesn't constrain the affine set further after adding the non-redundant /// constraints. The LexSimplex class provides support for computing the -/// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class +/// lexicographic minimum of an IntegerRelation. The SymbolicLexSimplex class /// provides support for computing symbolic lexicographic minimums. All of these /// classes can be constructed from an IntegerRelation, and all inherit common /// functionality from SimplexBase. @@ -529,18 +529,18 @@ std::optional maybeGetNonIntegralVarRow() const; }; -/// Represents the result of a symbolic lexicographic minimization computation. -struct SymbolicLexMin { - SymbolicLexMin(const PresburgerSpace &space) - : lexmin(space), +/// Represents the result of a symbolic lexicographic optimization computation. +struct SymbolicLexOpt { + SymbolicLexOpt(const PresburgerSpace &space) + : lexopt(space), unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {} - /// This maps assignments of symbols to the corresponding lexmin. + /// This maps assignments of symbols to the corresponding lexopt. /// Takes no value when no integer sample exists for the assignment or if the - /// lexmin is unbounded. - PWMAFunction lexmin; - /// Contains all assignments to the symbols that made the lexmin unbounded.
- /// Note that the symbols of the input set to the symbolic lexmin are dims - /// of this PrebsurgerSet. + /// lexopt is unbounded. + PWMAFunction lexopt; + /// Contains all assignments to the symbols that made the lexopt unbounded. + /// Note that the symbols of the input set to the symbolic lexopt are dims + /// of this PresburgerSet. PresburgerSet unboundedDomain; }; @@ -575,13 +575,13 @@ /// where it is. class SymbolicLexSimplex : public LexSimplexBase { public: - /// `constraints` is the set for which the symbolic lexmin will be computed. - /// `symbolDomain` is the set of values of the symbols for which the lexmin + /// `constraints` is the set for which the symbolic lexopt will be computed. + /// `symbolDomain` is the set of values of the symbols for which the lexopt /// will be computed. `symbolDomain` should have a dim var for every symbol in /// `constraints`, and no other vars. `isSymbol` specifies which vars of /// `constraints` should be considered as symbols. /// - /// The resulting SymbolicLexMin's space will be compatible with that of + /// The resulting SymbolicLexOpt's space will be compatible with that of /// symbolDomain. SymbolicLexSimplex(const IntegerRelation &constraints, const IntegerPolyhedron &symbolDomain, @@ -594,7 +594,7 @@ "there must be some non-symbols to optimize!"); } - /// An overload to select some subrange of ids as symbols for lexmin. + /// An overload to select some subrange of ids as symbols for lexopt. /// The symbol ids are the range of ids with absolute index /// [symbolOffset, symbolOffset + symbolDomain.getNumVars()) SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset, @@ -604,7 +604,7 @@ symbolOffset, symbolDomain.getNumVars())) {} - /// An overload to select the symbols of `constraints` as symbols for lexmin. + /// An overload to select the symbols of `constraints` as symbols for lexopt.
SymbolicLexSimplex(const IntegerRelation &constraints, const IntegerPolyhedron &symbolDomain) : SymbolicLexSimplex(constraints, @@ -614,7 +614,7 @@ "symbolDomain must have as many vars as constraints has symbols!"); } - /// The lexmin will be stored as a function `lexmin` from symbols to + /// The lexmin will be stored as a function `lexopt` from symbols to /// non-symbols in the result. /// /// For some values of the symbols, the lexmin may be unbounded. @@ -622,7 +622,7 @@ /// /// The spaces of the sets in the result are compatible with the symbolDomain /// passed in the SymbolicLexSimplex constructor. - SymbolicLexMin computeSymbolicIntegerLexMin(); + SymbolicLexOpt computeSymbolicIntegerLexMin(); private: /// Perform all pivots that do not require branching. @@ -670,7 +670,7 @@ /// Record a lexmin. The tableau must be consistent with all variables /// having symbolic samples with integer coefficients. - void recordOutput(SymbolicLexMin &result) const; + void recordOutput(SymbolicLexOpt &result) const; /// The symbol domain. IntegerPolyhedron domainPoly; diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -172,7 +172,7 @@ return PresburgerRelation(*this); // Move all the non-div locals to the end, as the current API to - // SymbolicLexMin requires these to form a contiguous range. + // SymbolicLexOpt requires these to form a contiguous range. // // Take a copy so we can perform mutations. IntegerRelation copy = *this; @@ -211,13 +211,13 @@ // exists, which is the union of the domain of the returned lexmin function // and the returned set of assignments to the "symbols" that makes the lexmin // unbounded.
- SymbolicLexMin lexminResult = + SymbolicLexOpt lexminResult = SymbolicLexSimplex(copy, /*symbolOffset*/ 0, IntegerPolyhedron(PresburgerSpace::getSetSpace( /*numDims=*/copy.getNumVars() - numNonDivLocals))) .computeSymbolicIntegerLexMin(); PresburgerRelation result = - lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain); + lexminResult.lexopt.getDomain().unionSet(lexminResult.unboundedDomain); // The result set might lie in the wrong space -- all its ids are dims. // Set it to the desired space and return. @@ -227,7 +227,7 @@ return result; } -SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const { +SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMin() const { // Symbol and Domain vars will be used as symbols for symbolic lexmin. // In other words, for every value of the symbols and domain, return the // lexmin value of the (range, locals). @@ -239,7 +239,7 @@ // Compute the symbolic lexmin of the dims and locals, with the symbols being // the actual symbols of this set. // The resultant space of lexmin is the space of the relation itself. - SymbolicLexMin result = + SymbolicLexOpt result = SymbolicLexSimplex(*this, IntegerPolyhedron(PresburgerSpace::getSetSpace( /*numDims=*/getNumDomainVars(), @@ -249,11 +249,49 @@ // We want to return only the lexmin over the dims, so strip the locals from // the computed lexmin. - result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(), - result.lexmin.getNumOutputs()); + result.lexopt.removeOutputs(result.lexopt.getNumOutputs() - getNumLocalVars(), + result.lexopt.getNumOutputs()); return result; } +/// findSymbolicIntegerLexMax is implemented using findSymbolicIntegerLexMin as +/// follows: +/// 1. A new relation is created which is `this` relation with the sign of the +/// range flipped; +/// 2. findSymbolicIntegerLexMin is called on the range negated relation to +/// compute the negated lexmax of `this` relation; +/// 3.
The sign of the negated lexmax is flipped and returned. +SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const { + IntegerRelation flippedRel = *this; + // Flip range sign by flipping the sign of range variables in all constraints. + for (unsigned j = getNumDomainVars(), + b = getNumDomainVars() + getNumRangeVars(); + j < b; j++) { + for (unsigned i = 0, a = getNumEqualities(); i < a; i++) + flippedRel.atEq(i, j) = -1 * atEq(i, j); + for (unsigned i = 0, a = getNumInequalities(); i < a; i++) + flippedRel.atIneq(i, j) = -1 * atIneq(i, j); + } + // Compute negated lexmax by computing lexmin. + SymbolicLexOpt flippedSymbolicIntegerLexMax = + flippedRel.findSymbolicIntegerLexMin(); + SymbolicLexOpt symbolicIntegerLexMax( + flippedSymbolicIntegerLexMax.lexopt.getSpace()); + // Get lexmax by negating each piece's output, keeping its local divs. + for (auto &flippedPiece : + flippedSymbolicIntegerLexMax.lexopt.getAllPieces()) { + Matrix mat = flippedPiece.output.getOutputMatrix(); + for (unsigned i = 0, e = mat.getNumRows(); i < e; i++) + mat.negateRow(i); + MultiAffineFunction maf(flippedPiece.output.getSpace(), mat, + flippedPiece.output.getDivs()); + symbolicIntegerLexMax.lexopt.addPiece({flippedPiece.domain, maf}); + } + symbolicIntegerLexMax.unboundedDomain = + flippedSymbolicIntegerLexMax.unboundedDomain; + return symbolicIntegerLexMax; +} + PresburgerRelation IntegerRelation::subtract(const PresburgerRelation &set) const { return PresburgerRelation(*this).subtract(set); diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -435,9 +435,9 @@ return moveRowUnknownToColumn(cutRow); } -void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const { +void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const { Matrix output(0, domainPoly.getNumVars() + 1); - output.reserveRows(result.lexmin.getNumOutputs()); +
output.reserveRows(result.lexopt.getNumOutputs()); for (const Unknown &u : var) { if (u.isSymbol) continue; @@ -469,10 +469,10 @@ } // Store the output in a MultiAffineFunction and add it the result. - PresburgerSpace funcSpace = result.lexmin.getSpace(); + PresburgerSpace funcSpace = result.lexopt.getSpace(); funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars()); - result.lexmin.addPiece( + result.lexopt.addPiece( {PresburgerSet(domainPoly), MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())}); } @@ -515,8 +515,8 @@ return success(); } -SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { - SymbolicLexMin result(PresburgerSpace::getRelationSpace( +SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { + SymbolicLexOpt result(PresburgerSpace::getRelationSpace( /*numDomain=*/domainPoly.getNumDimVars(), /*numRange=*/var.size() - nSymbol, /*numSymbols=*/domainPoly.getNumSymbolVars())); diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -1198,13 +1198,13 @@ ASSERT_NE(poly.getNumDimVars(), 0u); ASSERT_NE(poly.getNumSymbolVars(), 0u); - SymbolicLexMin result = poly.findSymbolicIntegerLexMin(); + SymbolicLexOpt result = poly.findSymbolicIntegerLexMin(); if (expectedLexminRepr.empty()) { - EXPECT_TRUE(result.lexmin.getDomain().isIntegerEmpty()); + EXPECT_TRUE(result.lexopt.getDomain().isIntegerEmpty()); } else { PWMAFunction expectedLexmin = parsePWMAF(expectedLexminRepr); - EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin)); + EXPECT_TRUE(result.lexopt.isEqual(expectedLexmin)); } if (expectedUnboundedDomainRepr.empty()) { diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp ---
a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -126,7 +126,7 @@ } TEST(IntegerRelationTest, symbolicLexmin) { - SymbolicLexMin lexmin = + SymbolicLexOpt lexmin = parseRelationFromSet("(a, x)[b] : (x - a >= 0, x - b >= 0)", 1) .findSymbolicIntegerLexMin(); @@ -135,5 +135,43 @@ {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (b)"}, // b }); EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty()); - EXPECT_TRUE(lexmin.lexmin.isEqual(expectedLexmin)); + EXPECT_TRUE(lexmin.lexopt.isEqual(expectedLexmin)); +} + +TEST(IntegerRelationTest, symbolicLexmax) { + SymbolicLexOpt lexmax1 = + parseRelationFromSet("(a, x)[b] : (a - x >= 0, b - x >= 0)", 1) + .findSymbolicIntegerLexMax(); + + PWMAFunction expectedLexmax1 = parsePWMAF({ + {"(a)[b] : (a - b >= 0)", "(a)[b] -> (b)"}, + {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (a)"}, + }); + + SymbolicLexOpt lexmax2 = + parseRelationFromSet("(i, j)[N] : (i >= 0, j >= 0, N - i - j >= 0)", 1) + .findSymbolicIntegerLexMax(); + + PWMAFunction expectedLexmax2 = parsePWMAF({ + {"(i)[N] : (i >= 0, N - i >= 0)", "(i)[N] -> (N - i)"}, + }); + + SymbolicLexOpt lexmax3 = + parseRelationFromSet("(x, y)[N] : (x >= 0, 2 * N - x >= 0, y >= 0, x - y " + "+ 2 * N >= 0, 4 * N - x - y >= 0)", + 1) + .findSymbolicIntegerLexMax(); + + PWMAFunction expectedLexmax3 = + parsePWMAF({{"(x)[N] : (x >= 0, 2 * N - x >= 0, x - N - 1 >= 0)", + "(x)[N] -> (4 * N - x)"}, + {"(x)[N] : (x >= 0, 2 * N - x >= 0, -x + N >= 0)", + "(x)[N] -> (x + 2 * N)"}}); + + EXPECT_TRUE(lexmax1.unboundedDomain.isIntegerEmpty()); + EXPECT_TRUE(lexmax1.lexopt.isEqual(expectedLexmax1)); + EXPECT_TRUE(lexmax2.unboundedDomain.isIntegerEmpty()); + EXPECT_TRUE(lexmax2.lexopt.isEqual(expectedLexmax2)); + EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty()); + EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexmax3)); }