diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -118,6 +118,22 @@
   /// Same as compose, provided for uniformity with applyDomain.
   void applyRange(const PresburgerRelation &rel);
 
+  /// Compute the symbolic integer lexmin of `rel` if `isMin` is true, and its
+  /// symbolic integer lexmax otherwise. Shared implementation of
+  /// findSymbolicIntegerLexMin and findSymbolicIntegerLexMax.
+  static SymbolicLexOpt findSymbolicIntegerLexOpt(const PresburgerRelation &rel,
+                                                  bool isMin);
+
+  /// Compute the symbolic integer lexmin of the relation, i.e., for every
+  /// assignment of the symbols and domain variables, the lexicographically
+  /// minimum value attained by the range.
+  SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+  /// Compute the symbolic integer lexmax of the relation, i.e., for every
+  /// assignment of the symbols and domain variables, the lexicographically
+  /// maximum value attained by the range.
+  SymbolicLexOpt findSymbolicIntegerLexMax() const;
+
   /// Return true if the set contains the given point, and false otherwise.
bool containsPoint(ArrayRef point) const; bool containsPoint(ArrayRef point) const { diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Utils.h" #include "llvm/ADT/STLExtras.h" @@ -185,6 +186,38 @@ compose(rel); } +SymbolicLexOpt +PresburgerRelation::findSymbolicIntegerLexOpt(const PresburgerRelation &rel, + bool isMin) { + PWMAFunction lexopt(rel.getSpace()); + PresburgerSet unboundedDomain = + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace( + rel.getNumDomainVars(), rel.getNumSymbolVars())); + for (const IntegerRelation &cs : rel.disjuncts) { + SymbolicLexOpt s(rel.getSpace()); + if (isMin) { + s = cs.findSymbolicIntegerLexMin(); + lexopt = lexopt.unionLexMin(s.lexopt); + } else { + s = cs.findSymbolicIntegerLexMax(); + lexopt = lexopt.unionLexMax(s.lexopt); + } + unboundedDomain = unboundedDomain.intersect(s.unboundedDomain); + } + SymbolicLexOpt result(rel.getSpace()); + result.lexopt = lexopt; + result.unboundedDomain = unboundedDomain; + return result; +} + +SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMin() const { + return findSymbolicIntegerLexOpt(*this, true); +} + +SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMax() const { + return findSymbolicIntegerLexOpt(*this, false); +} + /// Return the coefficients of the ineq in `rel` specified by `idx`. /// `idx` can refer not only to an actual inequality of `rel`, but also /// to either of the inequalities that make up an equality in `rel`. 
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "Parser.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
 
 #include
 #include
@@ -189,3 +190,107 @@
     EXPECT_TRUE(rel.isEqual(inverseRel));
   }
 }
+
+TEST(PresburgerRelationTest, symbolicLexOpt) {
+  PresburgerRelation rel1 = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "2 * N - x >= 0, 2 * N - y >= 0)",
+       "(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "x - N >= 0, M - x >= 0, y - 2 * N >= 0, M - y >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin1 = rel1.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin1 = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, 2 * N - x "
+       ">= 0)",
+       "(x)[N, M] -> (0)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, x - 2 * N "
+       "- 1 >= 0, M - x >= 0)",
+       "(x)[N, M] -> (2 * N)"},
+  });
+
+  SymbolicLexOpt lexmax1 = rel1.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax1 = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, N - 1 - x "
+       ">= 0)",
+       "(x)[N, M] -> (2 * N)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, x - N >= "
+       "0, M - x >= 0)",
+       "(x)[N, M] -> (M)"},
+  });
+
+  PresburgerRelation rel2 = parsePresburgerRelationFromPresburgerSet(
+      // x or y or z
+      // lexmin = (x, 0, 1 - x)
+      // lexmax = (x, 1, 1)
+      {"(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, 1 - z >= "
+       "0, x + y + z - 1 >= 0)",
+       // (x or y) and (y or z) and (z or x)
+       // lexmin = (x, 1 - x, 1)
+       // lexmax = (x, 1, 1)
+       "(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, 1 - z >= "
+       "0, x + y - 1 >= 0, y + z - 1 >= 0, z + x - 1 >= 0)",
+       // x => (not y) or (not z)
+       // lexmin = (x, 0, 0)
+       // lexmax = (x, 1, 1 - x)
+       "(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, 1 - z >= "
+       "0, 2 - x - y - z >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin2 = rel2.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin2 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (0, 0)"}});
+
+  SymbolicLexOpt lexmax2 = rel2.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax2 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (1, 1)"}});
+
+  PresburgerRelation rel3 = parsePresburgerRelationFromPresburgerSet(
+      // (x => u or v or w) and (x or v) and (x or (not w))
+      // lexmin = (x, 0, 1 - x, x)
+      // lexmax = (x, 1, 1, x)
+      {"(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, 1 - u >= "
+       "0, 1 - v >= 0, 1 - w >= 0, -x + u + v + w >= 0, x + v - 1>= 0, x - w "
+       ">= 0)",
+       // x => (u => (v => w)) and (x or (not v)) and (x or (not w))
+       // lexmin = (x, 0, 0, 0)
+       // lexmax = (x, 1, x, x)
+       "(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, 1 - u >= "
+       "0, 1 - v >= 0, 1 - w >= 0, -x - u - v + w + 2 >= 0, x - v >= 0, x - w "
+       ">= 0)",
+       // (x or (u or (not v))) and (x or ((not u) or w))
+       // and (x or (not v)) and (x or (not w))
+       // lexmin = (x, 0, 0, 0)
+       // lexmax = (x, x, x, x)
+       "(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, 1 - u >= "
+       "0, 1 - v >= 0, 1 - w >= 0, x + u - v >= 0, x - u + w >= 0, x - v >= 0, "
+       "x - w >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin3 = rel3.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin3 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (0, 0, 0)"}});
+
+  SymbolicLexOpt lexmax3 = rel3.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax3 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (1, 1, x)"}});
+
+  EXPECT_TRUE(lexmin1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin1.lexopt.isEqual(expectedLexMin1));
+  EXPECT_TRUE(lexmax1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax1.lexopt.isEqual(expectedLexMax1));
+  EXPECT_TRUE(lexmin2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin2.lexopt.isEqual(expectedLexMin2));
+  EXPECT_TRUE(lexmax2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax2.lexopt.isEqual(expectedLexMax2));
+  EXPECT_TRUE(lexmin3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin3.lexopt.isEqual(expectedLexMin3));
+  EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexMax3));
+}