diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h --- a/mlir/include/mlir/Analysis/Presburger/Fraction.h +++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h @@ -25,7 +25,7 @@ /// representable by 64-bit integers. struct Fraction { /// Default constructor initializes the represented rational number to zero. - Fraction() {} + Fraction() = default; /// Construct a Fraction from a numerator and denominator. Fraction(int64_t oNum, int64_t oDen) : num(oNum), den(oDen) { @@ -35,6 +35,13 @@ } } + /// Return the value of the fraction as an integer. This should only be called + /// when the fraction's value is really an integer. + int64_t getAsInteger() const { + assert(num % den == 0 && "Get as integer called on non-integral fraction!"); + return num / den; + } + /// The numerator and denominator, respectively. The denominator is always /// positive. int64_t num{0}, den{1}; diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h @@ -212,6 +212,13 @@ presburger_utils::MaybeOptimum> getRationalLexMin() const; + /// Same as above, but returns the lexicographically minimal integer point. + /// Note: this should be used only when the lexmin is really required. + /// For a generic integer sampling operation, findIntegerSample is more + /// robust and should be preferred. + presburger_utils::MaybeOptimum> + getIntegerLexMin() const; + /// Swap the posA^th identifier with the posB^th identifier. 
virtual void swapId(unsigned posA, unsigned posB); diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -265,6 +265,10 @@ /// Returns the unknown associated with row. Unknown &unknownFromRow(unsigned row); + /// Add a new row to the tableau and the associated data structures. The row + /// is initialized to zero, with denominator 1. Returns the last index of con. + unsigned addZeroRow(bool makeRestricted = false); + /// Add a new row to the tableau and the associated data structures. /// The new row is considered to be a constraint; the new Unknown lives in /// con. @@ -436,6 +440,12 @@ /// Return the lexicographically minimum rational solution to the constraints. presburger_utils::MaybeOptimum> getRationalLexMin(); + /// Return the lexicographically minimum integer solution to the constraints. + /// + /// Note: this should be used only when the lexmin is really needed. To obtain + /// any integer sample, use Simplex::findIntegerSample as that is more robust. + presburger_utils::MaybeOptimum> getIntegerLexMin(); + protected: /// Returns the current sample point, which may contain non-integer (rational) /// coordinates. Returns an empty optimum when the tableau is empty. @@ -446,6 +456,15 @@ presburger_utils::MaybeOptimum> getRationalSample() const; + /// Given a row that has a non-integer sample value, add an inequality such + /// that this fractional sample value is cut away from the polytope. The added + /// inequality will be such that no integer points are removed. + /// + /// Returns whether the cut constraint could be enforced, i.e. failure if the + /// cut made the polytope empty, and success if it didn't. Failure status + /// indicates that the polytope didn't have any integer points. + LogicalResult addCut(unsigned row); + /// Undo the addition of the last constraint. This is only called while /// rolling back. 
void undoLastConstraint() final; @@ -460,6 +479,10 @@ /// Otherwise, return an empty optional. Optional maybeGetViolatedRow() const; + /// Get a row corresponding to a var that has a non-integral sample value, if + /// one exists. Otherwise, return an empty optional. + Optional maybeGetNonIntegeralVarRow() const; + /// Given two potential pivot columns for a row, return the one that results /// in the lexicographically smallest sample vector. unsigned getLexMinPivotColumn(unsigned row, unsigned colA, diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp --- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp @@ -92,6 +92,26 @@ return maybeLexMin; } +MaybeOptimum> +IntegerPolyhedron::getIntegerLexMin() const { + assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); + MaybeOptimum> maybeLexMin = + LexSimplex(*this).getIntegerLexMin(); + + if (!maybeLexMin.isBounded()) + return maybeLexMin.getKind(); + + // The Simplex returns the lexmin over all the variables including locals. But + // locals are not actually part of the space and should not be returned in the + // result. Since the locals are placed last in the list of identifiers, they + // will be minimized last in the lexmin. So simply truncating out the locals + // from the end of the answer gives the desired lexmin over the dimensions. 
+ assert(maybeLexMin->size() == getNumIds() && + "Incorrect number of vars in lexMin!"); + maybeLexMin->resize(getNumDimAndSymbolIds()); + return maybeLexMin; +} + unsigned IntegerPolyhedron::insertDimId(unsigned pos, unsigned num) { return insertId(IdKind::SetDim, pos, num); } diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -59,13 +59,7 @@ return unknownFromIndex(rowUnknown[row]); } -/// Add a new row to the tableau corresponding to the given constant term and -/// list of coefficients. The coefficients are specified as a vector of -/// (variable index, coefficient) pairs. -unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { - assert(coeffs.size() == var.size() + 1 && - "Incorrect number of coefficients!"); - +unsigned SimplexBase::addZeroRow(bool makeRestricted) { ++nRow; // If the tableau is not big enough to accomodate the extra row, we extend it. if (nRow >= tableau.getNumRows()) @@ -77,6 +71,17 @@ tableau.fillRow(nRow - 1, 0); tableau(nRow - 1, 0) = 1; + return con.size() - 1; +} + +/// Add a new row to the tableau corresponding to the given constant term and +/// list of coefficients. coeffs holds one coefficient per variable in var, +/// followed by the constant term as its last element. +unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { + assert(coeffs.size() == var.size() + 1 && + "Incorrect number of coefficients!"); + + addZeroRow(makeRestricted); tableau(nRow - 1, 1) = coeffs.back(); if (usingBigM) { // When the lexicographic pivot rule is used, instead of the variables @@ -164,6 +169,56 @@ return getRationalSample(); } +LogicalResult LexSimplex::addCut(unsigned row) { + int64_t denom = tableau(row, 0); + addZeroRow(/*makeRestricted=*/true); + tableau(nRow - 1, 0) = denom; + tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom); + tableau(nRow - 1, 2) = 0; // M has all factors in it. 
+ for (unsigned col = 3; col < nCol; ++col) + tableau(nRow - 1, col) = mod(tableau(row, col), denom); + return moveRowUnknownToColumn(nRow - 1); +} + +Optional LexSimplex::maybeGetNonIntegeralVarRow() const { + for (const Unknown &u : var) { + if (u.orientation == Orientation::Column) + continue; + // If the sample value is of the form (a/d)M + b/d, we need b to be + // divisible by d. We assume M is very large and contains all possible + // factors and is divisible by everything. + unsigned row = u.pos; + if (tableau(row, 1) % tableau(row, 0) != 0) + return row; + } + return {}; +} + +MaybeOptimum> LexSimplex::getIntegerLexMin() { + while (!empty) { + restoreRationalConsistency(); + if (empty) + return OptimumKind::Empty; + + if (Optional maybeRow = maybeGetNonIntegeralVarRow()) { + // Failure occurs when the polytope is integer empty. + if (failed(addCut(*maybeRow))) + return OptimumKind::Empty; + continue; + } + + MaybeOptimum> sample = getRationalSample(); + assert(!sample.isEmpty() && "If we reached here the sample should exist!"); + if (sample.isUnbounded()) + return OptimumKind::Unbounded; + return llvm::to_vector<8>(llvm::map_range( + *sample, [](const Fraction &f) { return f.getAsInteger(); })); + } + + // Polytope is integer empty. 
+ return OptimumKind::Empty; +} + bool LexSimplex::rowIsViolated(unsigned row) const { if (tableau(row, 2) < 0) return true; diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/IntegerPolyhedron.h" #include "./Utils.h" +#include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/IR/MLIRContext.h" #include @@ -36,29 +37,53 @@ return set; } +static void dump(ArrayRef vec) { + for (int64_t x : vec) + llvm::errs() << x << ' '; + llvm::errs() << '\n'; +} + /// If fn is TestFunction::Sample (default): -/// If hasSample is true, check that findIntegerSample returns a valid sample -/// for the IntegerPolyhedron poly. -/// If hasSample is false, check that findIntegerSample returns None. +/// +/// If hasSample is true, check that findIntegerSample returns a valid sample +/// for the IntegerPolyhedron poly. Also check that getIntegerLexMin finds a +/// non-empty lexmin. +/// +/// If hasSample is false, check that findIntegerSample returns None and +/// getIntegerLexMin returns Empty. /// /// If fn is TestFunction::Empty, check that isIntegerEmpty returns the /// opposite of hasSample. 
static void checkSample(bool hasSample, const IntegerPolyhedron &poly, TestFunction fn = TestFunction::Sample) { Optional> maybeSample; + MaybeOptimum> maybeLexMin; switch (fn) { case TestFunction::Sample: maybeSample = poly.findIntegerSample(); + maybeLexMin = poly.getIntegerLexMin(); + if (!hasSample) { EXPECT_FALSE(maybeSample.hasValue()); if (maybeSample.hasValue()) { - for (auto x : *maybeSample) - llvm::errs() << x << ' '; - llvm::errs() << '\n'; + llvm::errs() << "findIntegerSample gave sample: "; + dump(*maybeSample); + } + + EXPECT_TRUE(maybeLexMin.isEmpty()); + if (maybeLexMin.isBounded()) { + llvm::errs() << "getIntegerLexMin gave sample: "; + dump(*maybeLexMin); } } else { ASSERT_TRUE(maybeSample.hasValue()); EXPECT_TRUE(poly.containsPoint(*maybeSample)); + + ASSERT_FALSE(maybeLexMin.isEmpty()); + if (maybeLexMin.isUnbounded()) + EXPECT_TRUE(Simplex(poly).isUnbounded()); + if (maybeLexMin.isBounded()) + EXPECT_TRUE(poly.containsPoint(*maybeLexMin)); } break; case TestFunction::Empty: @@ -1138,6 +1163,33 @@ parsePoly("(x) : (2*x >= 0, -x - 1 >= 0)", &context)); } +/// Check that poly has a bounded integer lexmin and that it equals min. +void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef min) { + auto lexMin = poly.getIntegerLexMin(); + ASSERT_TRUE(lexMin.isBounded()); + EXPECT_EQ(ArrayRef(*lexMin), min); +} + +/// Check that the integer lexmin of poly has the given non-Bounded kind. +void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) { + ASSERT_NE(kind, OptimumKind::Bounded) + << "Use expectIntegerLexMin for bounded min"; + EXPECT_EQ(poly.getIntegerLexMin().getKind(), kind); +} + +TEST(IntegerPolyhedronTest, getIntegerLexMin) { + MLIRContext context; + expectIntegerLexMin(parsePoly("(x, y, z) : (2*x + 13 >= 0, 4*y - 3*x - 2 >= " + "0, 11*z + 5*y - 3*x + 7 >= 0)", + &context), + {-6, -4, 0}); + // Similar to above but no lower bound on z. 
+ expectNoIntegerLexMin(OptimumKind::Unbounded, + parsePoly("(x, y, z) : (2*x + 13 >= 0, 4*y - 3*x - 2 " + ">= 0, -11*z + 5*y - 3*x + 7 >= 0)", + &context)); +} + static void expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly, Optional trueVolume,