diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -357,7 +357,7 @@ isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom); } -/// This proceeds similarly to LexSimplex::addCut(). We are given a row that has +/// This proceeds similarly to LexSimplexBase::addCut(). We are given a row that has /// a symbolic sample value with fractional coefficients. /// /// Let the row be @@ -374,10 +374,11 @@ /// sum_i (b_i%d)y_i = ((-c%d) + sum_i (-a_i%d)s_i)%d + k*d for some integer k /// /// where we take a modulo of the whole symbolic expression on the right to -/// bring it into the range [0, d - 1]. Therefore, as in LexSimplex::addCut, +/// bring it into the range [0, d - 1]. Therefore, as in addCut(), /// k is the quotient on dividing the LHS by d, and since LHS >= 0, we have -/// k >= 0 as well. We realize the modulo of the symbolic expression by adding a -/// division variable +/// k >= 0 as well. If all the a_i are divisible by d, then we can add the +/// constraint directly. Otherwise, we realize the modulo of the symbolic +/// expression by adding a division variable /// /// q = ((-c%d) + sum_i (-a_i%d)s_i)/d /// @@ -392,16 +393,22 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { int64_t d = tableau(row, 0); - // Add the division variable `q` described above to the symbol domain. - // q = ((-c%d) + sum_i (-a_i%d)s_i)/d. + // Make the division variable `q = ((-c%d) + sum_i (-a_i%d)s_i)/d`. SmallVector divCoeffs; divCoeffs.reserve(nSymbol + 1); int64_t divDenom = d; for (unsigned col = 3; col < 3 + nSymbol; ++col) divCoeffs.push_back(mod(-tableau(row, col), divDenom)); // (-a_i%d)s_i divCoeffs.push_back(mod(-tableau(row, 1), divDenom)); // -c%d. - normalizeDiv(divCoeffs, divDenom); + + if (divDenom == 1) { + // The symbolic sample numerator is divisible by the denominator, + // so the division isn't needed. We can add the constraint directly, + // i.e., ignore the symbols and add a regular cut as in addCut(). + return addCut(row); + } + domainSimplex.addDivisionVariable(divCoeffs, divDenom); domainPoly.addLocalFloorDiv(divCoeffs, divDenom);