diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -185,6 +185,12 @@
   /// Add new variables to the end of the list of variables.
   void appendVariable(unsigned count = 1);
 
+  /// Append a new variable to the simplex and constrain it such that its only
+  /// integer value is the floor div of `coeffs` and `denom`. `coeffs` holds
+  /// one coefficient per existing variable followed by the constant term.
+  /// `denom` must be positive.
+  void addDivisionVariable(ArrayRef<int64_t> coeffs, int64_t denom);
+
   /// Mark the tableau as being empty.
   void markEmpty();
 
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -818,6 +818,34 @@
   }
 }
 
+/// We add the usual floor division constraints:
+/// `0 <= coeffs - denom*q <= denom - 1`, where `q` is the new division
+/// variable.
+///
+/// This constrains the remainder `coeffs - denom*q` to be in the
+/// range `[0, denom - 1]`, which fixes the integer value of the quotient `q`.
+/// `denom` must be positive: for `denom <= 0` the range `[0, denom - 1]` is
+/// empty, so the two constraints would contradict each other.
+void SimplexBase::addDivisionVariable(ArrayRef<int64_t> coeffs, int64_t denom) {
+  assert(denom > 0 && "Denominator must be positive!");
+  assert(!coeffs.empty() && "coeffs must at least hold the constant term!");
+  appendVariable();
+
+  // Remainder lower bound: coeffs - denom*q >= 0. The new variable's
+  // coefficient is placed just before the constant term, which stays last.
+  SmallVector<int64_t, 8> ineq(coeffs.begin(), coeffs.end());
+  int64_t constTerm = ineq.back();
+  ineq.back() = -denom;
+  ineq.push_back(constTerm);
+  addInequality(ineq);
+
+  // Remainder upper bound: (denom - 1) - (coeffs - denom*q) >= 0.
+  for (int64_t &coeff : ineq)
+    coeff = -coeff;
+  ineq.back() += denom - 1;
+  addInequality(ineq);
+}
+
 void SimplexBase::appendVariable(unsigned count) {
   if (count == 0)
     return;
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -538,3 +538,14 @@
   EXPECT_TRUE(sim2.isRationalSubsetOf(s2));
   EXPECT_FALSE(sim2.isRationalSubsetOf(empty));
 }
+
+TEST(SimplexTest, addDivisionVariable) {
+  Simplex simplex(/*nVar=*/1);
+  simplex.addDivisionVariable({1, 0}, 2);
+  simplex.addInequality({1, 0, -3}); // x >= 3.
+  simplex.addInequality({-1, 0, 9}); // x <= 9.
+  Optional<SmallVector<int64_t, 8>> sample = simplex.findIntegerSample();
+  ASSERT_TRUE(sample.hasValue());
+  // x is non-negative here, so C++ truncating division equals floor division.
+  EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
+}