diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -629,6 +629,10 @@
   /// The last element is the constant term. This ignores the big M coefficient.
   SmallVector<int64_t, 8> getSymbolicSampleNumerator(unsigned row) const;
 
+  /// Get an affine inequality in the symbols with integer coefficients that
+  /// holds iff the symbolic sample of the specified row is non-negative.
+  SmallVector<int64_t, 8> getSymbolicSampleIneq(unsigned row) const;
+
   /// Return whether all the coefficients of the symbolic sample are integers.
   ///
   /// This does not consult the domain to check if the specified expression
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -130,6 +130,17 @@
     SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
     llvm::function_ref<bool(unsigned i, unsigned j)> merge);
 
+/// Compute the gcd of the range.
+int64_t gcdRange(ArrayRef<int64_t> range);
+
+/// Divide the range by its gcd and return the gcd.
+int64_t normalizeRange(MutableArrayRef<int64_t> range);
+
+/// Normalize the given (numerator, denominator) pair by dividing out the
+/// common factors between them. The numerator here is an affine expression
+/// with integer coefficients. The denominator must be positive.
+void normalizeDiv(MutableArrayRef<int64_t> num, int64_t &denom);
+
 /// Return `coeffs` with all the elements negated.
 SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs);
 
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -151,19 +151,7 @@
 /// Normalize the row by removing factors that are common between the
 /// denominator and all the numerator coefficients.
 void SimplexBase::normalizeRow(unsigned row) {
-  int64_t gcd = 0;
-  for (unsigned col = 0; col < nCol; ++col) {
-    gcd = llvm::greatestCommonDivisor(gcd, std::abs(tableau(row, col)));
-    // If the gcd becomes 1 then the row is already normalized.
-    if (gcd == 1)
-      return;
-  }
-
-  // Note that the gcd can never become zero since the first element of the row,
-  // the denominator, is non-zero.
-  assert(gcd != 0);
-  for (unsigned col = 0; col < nCol; ++col)
-    tableau(row, col) /= gcd;
+  normalizeRange(tableau.getRow(row));
 }
 
 namespace {
@@ -336,6 +324,14 @@
   return sample;
 }
 
+SmallVector<int64_t, 8>
+SymbolicLexSimplex::getSymbolicSampleIneq(unsigned row) const {
+  SmallVector<int64_t, 8> sample = getSymbolicSampleNumerator(row);
+  // The inequality is equivalent to the GCD-normalized one.
+  normalizeRange(sample);
+  return sample;
+}
+
 void LexSimplexBase::appendSymbol() {
   appendVariable();
   swapColumns(3 + nSymbol, nCol - 1);
@@ -391,14 +387,16 @@
 
   // Add the division variable `q` described above to the symbol domain.
   // q = ((-c%d) + sum_i (-a_i%d)s_i)/d.
-  SmallVector<int64_t, 8> domainDivCoeffs;
-  domainDivCoeffs.reserve(nSymbol + 1);
+  SmallVector<int64_t, 8> divCoeffs;
+  divCoeffs.reserve(nSymbol + 1);
+  int64_t divDenom = d;
   for (unsigned col = 3; col < 3 + nSymbol; ++col)
-    domainDivCoeffs.push_back(mod(-tableau(row, col), d)); // (-a_i%d)s_i
-  domainDivCoeffs.push_back(mod(-tableau(row, 1), d));     // -c%d.
+    divCoeffs.push_back(mod(-tableau(row, col), divDenom)); // (-a_i%d)s_i
+  divCoeffs.push_back(mod(-tableau(row, 1), divDenom));     // -c%d.
 
-  domainSimplex.addDivisionVariable(domainDivCoeffs, d);
-  domainPoly.addLocalFloorDiv(domainDivCoeffs, d);
+  normalizeDiv(divCoeffs, divDenom);
+  domainSimplex.addDivisionVariable(divCoeffs, divDenom);
+  domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
 
   // Update `this` to account for the additional symbol we just added.
   appendSymbol();
@@ -463,7 +461,7 @@
   for (unsigned row = 0; row < nRow; ++row) {
     if (tableau(row, 2) > 0)
       continue;
-    if (domainSimplex.isSeparateInequality(getSymbolicSampleNumerator(row))) {
+    if (domainSimplex.isSeparateInequality(getSymbolicSampleIneq(row))) {
       // Sample numerator always takes negative values in the symbol domain.
       return row;
     }
@@ -539,7 +537,7 @@
         assert(tableau(splitRow, 2) == 0 &&
                "Non-branching pivots should have been handled already!");
 
-        symbolicSample = getSymbolicSampleNumerator(splitRow);
+        symbolicSample = getSymbolicSampleIneq(splitRow);
         if (domainSimplex.isRedundantInequality(symbolicSample))
           continue;
 
@@ -617,7 +615,8 @@
       assert(u.orientation == Orientation::Row &&
              "The split row should have been returned to row orientation!");
       SmallVector<int64_t, 8> splitIneq =
-          getComplementIneq(getSymbolicSampleNumerator(u.pos));
+          getComplementIneq(getSymbolicSampleIneq(u.pos));
+      normalizeRange(splitIneq);
       if (moveRowUnknownToColumn(u.pos).failed()) {
         // The unknown can't be made non-negative; return.
         --level;
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -304,6 +304,37 @@
   }
 }
 
+int64_t presburger::gcdRange(ArrayRef<int64_t> range) {
+  int64_t gcd = 0;
+  for (int64_t elem : range) {
+    gcd = llvm::greatestCommonDivisor(gcd, std::abs(elem));
+    // The gcd cannot get any smaller than one, so stop scanning early.
+    if (gcd == 1)
+      return gcd;
+  }
+  return gcd;
+}
+
+int64_t presburger::normalizeRange(MutableArrayRef<int64_t> range) {
+  int64_t gcd = gcdRange(range);
+  // Nothing to divide out; this also avoids dividing by a zero gcd.
+  if (gcd == 0 || gcd == 1)
+    return gcd;
+  for (int64_t &elem : range)
+    elem /= gcd;
+  return gcd;
+}
+
+void presburger::normalizeDiv(MutableArrayRef<int64_t> num, int64_t &denom) {
+  // A positive denominator keeps the gcd below non-zero even when `num` is
+  // empty or all zeros, so the divisions cannot be by zero.
+  assert(denom > 0 && "denom must be positive!");
+  int64_t gcd = llvm::greatestCommonDivisor(gcdRange(num), denom);
+  for (int64_t &coeff : num)
+    coeff /= gcd;
+  denom /= gcd;
+}
+
 SmallVector<int64_t, 8> presburger::getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
   SmallVector<int64_t, 8> negatedCoeffs;
   negatedCoeffs.reserve(coeffs.size());