diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -29,7 +29,7 @@
 class IntegerPolyhedron;
 class PresburgerSet;
 class PresburgerRelation;
-struct SymbolicLexMin;
+struct SymbolicLexOpt;
 
 /// The type of bound: equal, lower bound or upper bound.
 enum class BoundType { EQ, LB, UB };
@@ -659,15 +659,39 @@
   /// x = a if b <= a, a <= c
   /// x = b if a < b, b <= c
   ///
-  /// This function is stored in the `lexmin` function in the result.
+  /// This function is stored in the `lexopt` function in the result.
   /// Some assignments to the symbols might make the set empty.
   /// Such points are not part of the function's domain.
   /// In the above example, this happens when max(a, b) > c.
   ///
   /// For some values of the symbols, the lexmin may be unbounded.
-  /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
+  /// `SymbolicLexOpt` stores these parts of the symbolic domain in a separate
   /// `PresburgerSet`, `unboundedDomain`.
-  SymbolicLexMin findSymbolicIntegerLexMin() const;
+  SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+  /// Compute the symbolic integer lexmax of the relation.
+  ///
+  /// This finds, for every assignment to the symbols and domain,
+  /// the lexicographically maximum value attained by the range.
+  ///
+  /// For example, the symbolic lexmax of the set
+  ///
+  /// (x)[a, b, c] : (a <= x, x <= b, x <= c)
+  ///
+  /// can be written as
+  ///
+  /// x = b if b <= c, a <= b
+  /// x = c if c < b, a <= c
+  ///
+  /// This function is stored in the `lexopt` function in the result.
+  /// Some assignments to the symbols might make the set empty.
+  /// Such points are not part of the function's domain.
+  /// In the above example, this happens when min(b, c) < a.
+  ///
+  /// For some values of the symbols, the lexmax may be unbounded.
+  /// `SymbolicLexOpt` stores these parts of the symbolic domain in a separate
+  /// `PresburgerSet`, `unboundedDomain`.
+  SymbolicLexOpt findSymbolicIntegerLexMax() const;
 
   /// Return the set difference of this set and the given set, i.e.,
   /// return `this \ set`.
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -43,7 +43,7 @@
 /// these constraints that are redundant, i.e. a subset of constraints that
 /// doesn't constrain the affine set further after adding the non-redundant
 /// constraints. The LexSimplex class provides support for computing the
-/// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class
+/// lexicographic minimum of an IntegerRelation. The SymbolicLexSimplex class
 /// provides support for computing symbolic lexicographic minimums. All of these
 /// classes can be constructed from an IntegerRelation, and all inherit common
 /// functionality from SimplexBase.
@@ -530,17 +530,17 @@
 };
 
-/// Represents the result of a symbolic lexicographic minimization computation.
-struct SymbolicLexMin {
-  SymbolicLexMin(const PresburgerSpace &space)
-      : lexmin(space),
+/// Represents the result of a symbolic lexicographic optimization computation.
+struct SymbolicLexOpt {
+  SymbolicLexOpt(const PresburgerSpace &space)
+      : lexopt(space),
         unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {}
 
-  /// This maps assignments of symbols to the corresponding lexmin.
-  /// Takes no value when no integer sample exists for the assignment or if the
-  /// lexmin is unbounded.
-  PWMAFunction lexmin;
-  /// Contains all assignments to the symbols that made the lexmin unbounded.
-  /// Note that the symbols of the input set to the symbolic lexmin are dims
-  /// of this PrebsurgerSet.
+  /// This maps assignments of symbols to the corresponding lexopt.
+  /// Takes no value when no integer sample exists for the assignment or if the
+  /// lexopt is unbounded.
+  PWMAFunction lexopt;
+  /// Contains all assignments to the symbols that made the lexopt unbounded.
+  /// Note that the symbols of the input set to the symbolic lexopt are dims
+  /// of this PresburgerSet.
   PresburgerSet unboundedDomain;
 };
@@ -581,7 +581,7 @@
   /// `constraints`, and no other vars. `isSymbol` specifies which vars of
   /// `constraints` should be considered as symbols.
   ///
-  /// The resulting SymbolicLexMin's space will be compatible with that of
+  /// The resulting SymbolicLexOpt's space will be compatible with that of
   /// symbolDomain.
   SymbolicLexSimplex(const IntegerRelation &constraints,
                      const IntegerPolyhedron &symbolDomain,
@@ -622,7 +622,7 @@
   ///
   /// The spaces of the sets in the result are compatible with the symbolDomain
   /// passed in the SymbolicLexSimplex constructor.
-  SymbolicLexMin computeSymbolicIntegerLexMin();
+  SymbolicLexOpt computeSymbolicIntegerLexMin();
 
 private:
   /// Perform all pivots that do not require branching.
@@ -670,7 +670,7 @@
 
   /// Record a lexmin. The tableau must be consistent with all variables
   /// having symbolic samples with integer coefficients.
-  void recordOutput(SymbolicLexMin &result) const;
+  void recordOutput(SymbolicLexOpt &result) const;
 
   /// The symbol domain.
   IntegerPolyhedron domainPoly;
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -172,7 +172,7 @@
     return PresburgerRelation(*this);
 
   // Move all the non-div locals to the end, as the current API to
-  // SymbolicLexMin requires these to form a contiguous range.
+  // SymbolicLexSimplex requires these to form a contiguous range.
   //
   // Take a copy so we can perform mutations.
   IntegerRelation copy = *this;
@@ -211,13 +211,13 @@
   // exists, which is the union of the domain of the returned lexmin function
   // and the returned set of assignments to the "symbols" that makes the lexmin
   // unbounded.
-  SymbolicLexMin lexminResult =
+  SymbolicLexOpt lexminResult =
       SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
                              /*numDims=*/copy.getNumVars() - numNonDivLocals)))
           .computeSymbolicIntegerLexMin();
   PresburgerRelation result =
-      lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
+      lexminResult.lexopt.getDomain().unionSet(lexminResult.unboundedDomain);
 
   // The result set might lie in the wrong space -- all its ids are dims.
   // Set it to the desired space and return.
@@ -227,7 +227,7 @@
   return result;
 }
 
-SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
+SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMin() const {
   // Symbol and Domain vars will be used as symbols for symbolic lexmin.
   // In other words, for every value of the symbols and domain, return the
   // lexmin value of the (range, locals).
@@ -239,7 +239,7 @@
   // Compute the symbolic lexmin of the dims and locals, with the symbols being
   // the actual symbols of this set.
   // The resultant space of lexmin is the space of the relation itself.
-  SymbolicLexMin result =
+  SymbolicLexOpt result =
       SymbolicLexSimplex(*this,
                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
                              /*numDims=*/getNumDomainVars(),
@@ -249,11 +249,48 @@
 
   // We want to return only the lexmin over the dims, so strip the locals from
   // the computed lexmin.
-  result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(),
-                              result.lexmin.getNumOutputs());
+  result.lexopt.removeOutputs(result.lexopt.getNumOutputs() - getNumLocalVars(),
+                              result.lexopt.getNumOutputs());
   return result;
 }
 
+SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const {
+  // lexmax(R) is the negation of lexmin(R'), where R' is R with the sign of
+  // every range variable flipped: maximizing x is minimizing -x, and flipping
+  // each range variable reverses the lexicographic order of the range tuples.
+  IntegerRelation flippedRel = *this;
+  // Flip range sign by flipping the sign of range variables in all constraints.
+  for (unsigned j = getNumDomainVars(),
+                b = getNumDomainVars() + getNumRangeVars();
+       j < b; j++) {
+    for (unsigned i = 0, a = getNumEqualities(); i < a; i++)
+      flippedRel.atEq(i, j) = -atEq(i, j);
+    for (unsigned i = 0, a = getNumInequalities(); i < a; i++)
+      flippedRel.atIneq(i, j) = -atIneq(i, j);
+  }
+
+  // Compute the negated lexmax as the lexmin of the flipped relation.
+  SymbolicLexOpt flippedLexMin = flippedRel.findSymbolicIntegerLexMin();
+  SymbolicLexOpt symbolicIntegerLexMax(flippedLexMin.lexopt.getSpace());
+
+  // Get lexmax by negating every output expression of every piece. The
+  // division representation of the piece's locals must be carried over, since
+  // the negated outputs may still refer to those locals.
+  for (const auto &flippedPiece : flippedLexMin.lexopt.getAllPieces()) {
+    Matrix mat = flippedPiece.output.getOutputMatrix();
+    for (unsigned i = 0, e = mat.getNumRows(); i < e; i++)
+      mat.negateRow(i);
+    MultiAffineFunction maf(flippedPiece.output.getSpace(), mat,
+                            flippedPiece.output.getDivs());
+    symbolicIntegerLexMax.lexopt.addPiece({flippedPiece.domain, maf});
+  }
+
+  // Only range variables were flipped, so the set of symbol/domain values for
+  // which the optimum is unbounded carries over unchanged.
+  symbolicIntegerLexMax.unboundedDomain = flippedLexMin.unboundedDomain;
+  return symbolicIntegerLexMax;
+}
+
 PresburgerRelation
 IntegerRelation::subtract(const PresburgerRelation &set) const {
   return PresburgerRelation(*this).subtract(set);
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -435,9 +435,9 @@
   return moveRowUnknownToColumn(cutRow);
 }
 
-void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
+void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const {
   Matrix output(0, domainPoly.getNumVars() + 1);
-  output.reserveRows(result.lexmin.getNumOutputs());
+  output.reserveRows(result.lexopt.getNumOutputs());
   for (const Unknown &u : var) {
     if (u.isSymbol)
       continue;
@@ -469,10 +469,10 @@
   }
 
-  // Store the output in a MultiAffineFunction and add it the result.
-  PresburgerSpace funcSpace = result.lexmin.getSpace();
+  // Store the output in a MultiAffineFunction and add it to the result.
+  PresburgerSpace funcSpace = result.lexopt.getSpace();
   funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars());
 
-  result.lexmin.addPiece(
+  result.lexopt.addPiece(
       {PresburgerSet(domainPoly),
        MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())});
 }
@@ -515,8 +515,8 @@
   return success();
 }
 
-SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
-  SymbolicLexMin result(PresburgerSpace::getRelationSpace(
+SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
+  SymbolicLexOpt result(PresburgerSpace::getRelationSpace(
       /*numDomain=*/domainPoly.getNumDimVars(),
       /*numRange=*/var.size() - nSymbol,
       /*numSymbols=*/domainPoly.getNumSymbolVars()));
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1198,13 +1198,13 @@
   ASSERT_NE(poly.getNumDimVars(), 0u);
   ASSERT_NE(poly.getNumSymbolVars(), 0u);
 
-  SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
+  SymbolicLexOpt result = poly.findSymbolicIntegerLexMin();
 
   if (expectedLexminRepr.empty()) {
-    EXPECT_TRUE(result.lexmin.getDomain().isIntegerEmpty());
+    EXPECT_TRUE(result.lexopt.getDomain().isIntegerEmpty());
   } else {
     PWMAFunction expectedLexmin = parsePWMAF(expectedLexminRepr);
-    EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin));
+    EXPECT_TRUE(result.lexopt.isEqual(expectedLexmin));
   }
 
   if (expectedUnboundedDomainRepr.empty()) {
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -126,7 +126,7 @@
 }
 
 TEST(IntegerRelationTest, symbolicLexmin) {
-  SymbolicLexMin lexmin =
+  SymbolicLexOpt lexmin =
       parseRelationFromSet("(a, x)[b] : (x - a >= 0, x - b >= 0)", 1)
           .findSymbolicIntegerLexMin();
 
@@ -135,5 +135,32 @@
       {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (b)"}, // b
   });
   EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty());
-  EXPECT_TRUE(lexmin.lexmin.isEqual(expectedLexmin));
+  EXPECT_TRUE(lexmin.lexopt.isEqual(expectedLexmin));
+}
+
+TEST(IntegerRelationTest, symbolicLexmax) {
+  SymbolicLexOpt lexmax =
+      parseRelationFromSet("(a, x)[b] : (a - x >= 0, b - x >= 0)", 1)
+          .findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexmax = parsePWMAF({
+      {"(a)[b] : (a - b >= 0)", "(a)[b] -> (b)"},
+      {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (a)"},
+  });
+  EXPECT_TRUE(lexmax.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax.lexopt.isEqual(expectedLexmax));
+}
+
+TEST(IntegerRelationTest, symbolicLexmaxWithDivs) {
+  // The lexmax here is a floor division of the domain variable, so its
+  // output refers to a local whose division representation must survive.
+  SymbolicLexOpt lexmax =
+      parseRelationFromSet("(a, x)[b] : (a - 2 * x >= 0)", 1)
+          .findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexmax = parsePWMAF({
+      {"(a)[b] : ()", "(a)[b] -> (a floordiv 2)"},
+  });
+  EXPECT_TRUE(lexmax.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax.lexopt.isEqual(expectedLexmax));
 }