diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -28,6 +28,7 @@ class IntegerPolyhedron; class PresburgerSet; class PresburgerRelation; +struct SymbolicLexMin; /// An IntegerRelation represents the set of points from a PresburgerSpace that /// satisfy a list of affine constraints. Affine constraints can be inequalities @@ -583,6 +584,30 @@ /// union of convex disjuncts. PresburgerRelation computeReprWithOnlyDivLocals() const; + /// Compute the symbolic integer lexmin of the relation. + /// + /// This finds, for every assignment to the symbols and domain, + /// the lexicographically minimum value attained by the range. + /// + /// For example, the symbolic lexmin of the set + /// + /// (x, y)[a, b, c] : (a <= x, b <= x, x <= c) + /// + /// can be written as + /// + /// x = a if b <= a, a <= c + /// x = b if a < b, b <= c + /// + /// This function is stored in the `lexmin` function in the result. + /// Some assignments to the symbols might make the set empty. + /// Such points are not part of the function's domain. + /// In the above example, this happens when max(a, b) > c. + /// + /// For some values of the symbols, the lexmin may be unbounded. + /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate + /// `PresburgerSet`, `unboundedDomain`. + SymbolicLexMin findSymbolicIntegerLexMin() const; + void print(raw_ostream &os) const; void dump() const; @@ -692,8 +717,6 @@ Matrix inequalities; }; -struct SymbolicLexMin; - /// An IntegerPolyhedron represents the set of points from a PresburgerSpace /// that satisfy a list of affine constraints. Affine constraints can be /// inequalities or equalities in the form: @@ -767,28 +790,6 @@ /// column position (i.e., not relative to the kind of variable) of the /// first added variable. 
unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override; - - /// Compute the symbolic integer lexmin of the polyhedron. - /// This finds, for every assignment to the symbols, the lexicographically - /// minimum value attained by the dimensions. For example, the symbolic lexmin - /// of the set - /// - /// (x, y)[a, b, c] : (a <= x, b <= x, x <= c) - /// - /// can be written as - /// - /// x = a if b <= a, a <= c - /// x = b if a < b, b <= c - /// - /// This function is stored in the `lexmin` function in the result. - /// Some assignments to the symbols might make the set empty. - /// Such points are not part of the function's domain. - /// In the above example, this happens when max(a, b) > c. - /// - /// For some values of the symbols, the lexmin may be unbounded. - /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate - /// `PresburgerSet`, `unboundedDomain`. - SymbolicLexMin findSymbolicIntegerLexMin() const; }; } // namespace presburger diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -226,12 +226,23 @@ return result; } -SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const { +SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const { + // Symbol and Domain vars will be used as symbols for symbolic lexmin. + // In other words, for every value of the symbols and domain, return the + // lexmin value of the (range, locals). + llvm::SmallBitVector isSymbol(getNumVars(), false); + isSymbol.set(getVarKindOffset(VarKind::Symbol), + getVarKindEnd(VarKind::Symbol)); + isSymbol.set(getVarKindOffset(VarKind::Domain), + getVarKindEnd(VarKind::Domain)); // Compute the symbolic lexmin of the dims and locals, with the symbols being // the actual symbols of this set. 
SymbolicLexMin result = - SymbolicLexSimplex(*this, IntegerPolyhedron(PresburgerSpace::getSetSpace( - /*numDims=*/getNumSymbolVars()))) + SymbolicLexSimplex(*this, + IntegerPolyhedron(PresburgerSpace::getSetSpace( + /*numDims=*/getNumDomainVars(), + /*numSymbols=*/getNumSymbolVars())), + isSymbol) .computeSymbolicIntegerLexMin(); // We want to return only the lexmin over the dims, so strip the locals from diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -1170,10 +1170,11 @@ PWMAFunction expectedLexmin = parsePWMAF(/*numInputs=*/poly.getNumSymbolVars(), - /*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr); + /*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr, + /*numSymbols=*/poly.getNumSymbolVars()); PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings( - poly.getNumSymbolVars(), expectedUnboundedDomainRepr); + /*numDims=*/0, expectedUnboundedDomainRepr, poly.getNumSymbolVars()); SymbolicLexMin result = poly.findSymbolicIntegerLexMin(); @@ -1200,114 +1201,116 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) { expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)", { - {"(a) : ()", {{1, 0}}}, // a + {"()[a] : ()", {{1, 0}}}, // a }); expectSymbolicIntegerLexMin( "(x)[a, b] : (x - a >= 0, x - b >= 0)", { - {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a - {"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b + {"()[a, b] : (a - b >= 0)", {{1, 0, 0}}}, // a + {"()[a, b] : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b }); expectSymbolicIntegerLexMin( "(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)", { - {"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a - {"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b - {"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c + {"()[a, b, c] : 
(a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a + {"()[a, b, c] : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b + {"()[a, b, c] : (c - a - 1 >= 0, c - b - 1 >= 0)", + {{0, 0, 1, 0}}}, // c }); expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)", { - {"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a) + {"()[a] : ()", {{1, 0}, {-1, 0}}}, // (a, -a) }); expectSymbolicIntegerLexMin( "(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)", { - {"(a) : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0) - {"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a) + {"()[a] : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0) + {"()[a] : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a) }); expectSymbolicIntegerLexMin( "(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)", { - {"(a, b, c) : (c - a - b >= 0)", + {"()[a, b, c] : (c - a - b >= 0)", {{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b) }); expectSymbolicIntegerLexMin( "(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)", { - {"(a, b, c) : ()", + {"()[a, b, c] : ()", {{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c) }); expectSymbolicIntegerLexMin( "(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)", { - {"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0 - {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 + {"()[a, b] : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0 + {"()[a, b] : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 }); expectSymbolicIntegerLexMin( "(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= " "0, a + b + x - 1 >= 0)", { - {"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)", - {{0, 0, 0}}}, // 0 - {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 + {"()[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= " + "0)", + {{0, 0, 0}}}, // 0 + {"()[a, b] : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 }); expectSymbolicIntegerLexMin( "(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + " "y + z - 1 >= 0)", { - {"(a, b) : 
(a >= 0, b >= 0, 1 - a - b >= 0)", + {"()[a, b] : (a >= 0, b >= 0, 1 - a - b >= 0)", {{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b) - {"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)", + {"()[a, b] : (a >= 0, b >= 0, a + b - 2 >= 0)", {{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0) }); expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)", { - {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a + {"()[a, b] : (a - b >= 0)", {{1, 0, 0}}}, // a }); expectSymbolicIntegerLexMin( "(q)[a] : (a - 1 - 3*q == 0, q >= 0)", { - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 1, 0}}}, // a floordiv 3 }); expectSymbolicIntegerLexMin( "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)", { - {"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3) - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 1}, {0, 1, 0}}}, // (1 a floordiv 3) }); expectSymbolicIntegerLexMin( "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)", { - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) - {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3) }); expectSymbolicIntegerLexMin( "(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)", { - {"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3) - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) - {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", + {"()[a] : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", {{0, 0, 2}, {0, 1, 0}}}, // (2, a 
floordiv 3) }); @@ -1323,11 +1326,11 @@ // What's the lexmin solution using exactly g true vars? "g - x - y - z - w == 0)", { - {"(g) : (g - 1 == 0)", + {"()[g] : (g - 1 == 0)", {{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0) - {"(g) : (g - 2 == 0)", + {"()[g] : (g - 2 == 0)", {{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1) - {"(g) : (g - 3 == 0)", + {"()[g] : (g - 3 == 0)", {{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1) }); @@ -1340,11 +1343,11 @@ // According to Bezout's lemma, 14x + 35y can take on all multiples // of 7 and no other values. So the solution exists iff r - a is a // multiple of 7. - {"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"}); + {"()[a, r] : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"}); // The lexmins are unbounded. expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {}, - {"(a) : ()"}); + {"()[a] : ()"}); // Test cases adapted from isl. expectSymbolicIntegerLexMin( @@ -1352,7 +1355,7 @@ // So b is minimized when c = b. "(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)", { - {"(a) : (a - 2*(a floordiv 2) == 0)", + {"()[a] : (a - 2*(a floordiv 2) == 0)", {{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2) }); @@ -1362,7 +1365,7 @@ "(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= " "0, b + 7 - 16*((8 + b) floordiv 16) >= 0)", { - {"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv " + {"()[a] : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv " "512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv " "512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)", {{0, 1, 0, 0}}}, // (a floordiv 2, a floordiv 2) @@ -1375,7 +1378,8 @@ "2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 " ">= 0)", {{ - "(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N " + "()[K, N, x, y] : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + " + "N " ">= 0, N + K - 2 - x >= 0, x - 4 >= 0)", {{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N) 
}}); diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "./Utils.h" +#include "mlir/Analysis/Presburger/Simplex.h" #include #include @@ -122,3 +123,20 @@ EXPECT_TRUE(map1.isEqual(map3)); } } + +TEST(IntegerRelationTest, symbolicLexmin) { + SymbolicLexMin lexmin = + parseRelationFromSet("(a, x)[b] : (x - a >= 0, x - b >= 0)", 1) + .findSymbolicIntegerLexMin(); + + PWMAFunction expectedLexmin = + parsePWMAF(/*numInputs=*/2, + /*numOutputs=*/1, + { + {"(a)[b] : (a - b >= 0)", {{1, 0, 0}}}, // a + {"(a)[b] : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b + }, + /*numSymbols=*/1); + EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty()); + EXPECT_TRUE(lexmin.lexmin.isEqual(expectedLexmin)); +} diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -17,6 +17,7 @@ #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" +#include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" @@ -40,9 +41,10 @@ /// are all valid IntegerSet representation and that all of them have the same /// number of dimensions as is specified by the numDims argument. 
inline PresburgerSet -parsePresburgerSetFromPolyStrings(unsigned numDims, ArrayRef<StringRef> strs) { - PresburgerSet set = - PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims)); +parsePresburgerSetFromPolyStrings(unsigned numDims, ArrayRef<StringRef> strs, + unsigned numSymbols = 0) { + PresburgerSet set = PresburgerSet::getEmpty( + PresburgerSpace::getSetSpace(numDims, numSymbols)); for (StringRef str : strs) set.unionInPlace(parsePoly(str)); return set; @@ -71,9 +73,9 @@ unsigned numSymbols = 0) { static MLIRContext context; - PWMAFunction result( - PresburgerSpace::getSetSpace(numInputs - numSymbols, numSymbols), - numOutputs); + PWMAFunction result(PresburgerSpace::getSetSpace( + /*numDims=*/numInputs - numSymbols, numSymbols), + numOutputs); for (const auto &pair : data) { IntegerPolyhedron domain = parsePoly(pair.first);