diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -118,6 +118,18 @@
   /// Same as compose, provided for uniformity with applyDomain.
   void applyRange(const PresburgerRelation &rel);
 
+  /// Compute the symbolic integer lexmin of the relation, i.e. for every
+  /// assignment of the symbols and domain the lexicographically minimum value
+  /// attained by the range. Assignments at which the minimum is unbounded are
+  /// collected in the returned `unboundedDomain`.
+  SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+  /// Compute the symbolic integer lexmax of the relation, i.e. for every
+  /// assignment of the symbols and domain the lexicographically maximum value
+  /// attained by the range. Assignments at which the maximum is unbounded are
+  /// collected in the returned `unboundedDomain`.
+  SymbolicLexOpt findSymbolicIntegerLexMax() const;
+
   /// Return true if the set contains the given point, and false otherwise.
   bool containsPoint(ArrayRef<MPInt> point) const;
   bool containsPoint(ArrayRef<int64_t> point) const {
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -186,6 +187,46 @@
   compose(rel);
 }
 
+/// Shared implementation of PresburgerRelation::findSymbolicIntegerLexMin
+/// (`isMin` = true) and findSymbolicIntegerLexMax (`isMin` = false).
+///
+/// Each disjunct is optimized on its own and the per-disjunct optima are
+/// merged with unionLexMin / unionLexMax. The optimum over a union is
+/// unbounded at a point as soon as any one disjunct is unbounded there, so the
+/// disjuncts' unbounded domains are combined by union, starting from the
+/// empty set (a relation with no disjuncts has no unbounded points).
+///
+/// NOTE(review): at a point in unboundedDomain where some other disjunct is
+/// still bounded, the merged lexopt keeps that disjunct's value, so callers
+/// should check unboundedDomain first. Consider dropping such points from it.
+static SymbolicLexOpt
+computeSymbolicIntegerLexOpt(const PresburgerRelation &rel, bool isMin) {
+  PWMAFunction lexopt(PresburgerSpace::getRelationSpace(
+      rel.getNumDomainVars(), rel.getNumRangeVars(), rel.getNumSymbolVars()));
+  PresburgerSet unboundedDomain = PresburgerSet::getEmpty(
+      PresburgerSpace::getSetSpace(rel.getNumDomainVars(),
+                                   rel.getNumSymbolVars()));
+  for (const IntegerRelation &cs : rel.getAllDisjuncts()) {
+    SymbolicLexOpt s = isMin ? cs.findSymbolicIntegerLexMin()
+                             : cs.findSymbolicIntegerLexMax();
+    lexopt = isMin ? lexopt.unionLexMin(s.lexopt)
+                   : lexopt.unionLexMax(s.lexopt);
+    unboundedDomain = unboundedDomain.unionSet(s.unboundedDomain);
+  }
+  SymbolicLexOpt result(rel.getSpace());
+  result.lexopt = lexopt;
+  result.unboundedDomain = unboundedDomain;
+  return result;
+}
+
+SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMin() const {
+  return computeSymbolicIntegerLexOpt(*this, /*isMin=*/true);
+}
+
+SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMax() const {
+  return computeSymbolicIntegerLexOpt(*this, /*isMin=*/false);
+}
+
 /// Return the coefficients of the ineq in `rel` specified by `idx`.
 /// `idx` can refer not only to an actual inequality of `rel`, but also
 /// to either of the inequalities that make up an equality in `rel`.
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "Parser.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -189,3 +190,44 @@
     EXPECT_TRUE(rel.isEqual(inverseRel));
   }
 }
+
+TEST(PresburgerRelationTest, symbolicLexOpt) {
+  // The relation maps x to y with symbols N >= 1 and M >= 2N + 1. It is the
+  // union of 0 <= x, y <= 2N and (N <= x <= M, 2N <= y <= M).
+  PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "2 * N - x >= 0, 2 * N - y >= 0)",
+       "(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "x - N >= 0, M - x >= 0, y - 2 * N >= 0, M - y >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin = rel.findSymbolicIntegerLexMin();
+
+  // The lexmin is 0 wherever the first disjunct applies (x <= 2N), else 2N.
+  PWMAFunction expectedLexMin = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, 2 * N - x "
+       ">= 0)",
+       "(x)[N, M] -> (0)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, x - 2 * N "
+       "- 1 >= 0, M - x >= 0)",
+       "(x)[N, M] -> (2 * N)"},
+  });
+
+  SymbolicLexOpt lexmax = rel.findSymbolicIntegerLexMax();
+
+  // The lexmax is 2N for x < N (first disjunct only), else M, as M > 2N.
+  PWMAFunction expectedLexMax = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, N - 1 - x "
+       ">= 0)",
+       "(x)[N, M] -> (2 * N)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, x - N >= "
+       "0, M - x >= 0)",
+       "(x)[N, M] -> (M)"},
+  });
+
+  // Both disjuncts bound y, so neither optimum is unbounded anywhere.
+  EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin.lexopt.isEqual(expectedLexMin));
+  EXPECT_TRUE(lexmax.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax.lexopt.isEqual(expectedLexMax));
+}