diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -157,11 +157,22 @@ bool containsPoint(ArrayRef point) const; /// Find pairs of inequalities identified by their position indices, using - /// which an explicit representation for each local variable can be computed - /// The pairs are stored as indices of upperbound, lowerbound - /// inequalities. If no such pair can be found, it is stored as llvm::None. - void getLocalReprLbUbPairs( + /// which an explicit representation for each local variable can be computed. + /// The pairs are stored as indices of upperbound, lowerbound inequalities. If + /// no such pair can be found, it is stored as llvm::None. + /// + /// The dividends of the explicit representations are stored in `dividends` + /// and the denominators in `denominators`. If no explicit representation + /// could be found for the `i^th` local identifier, `denominators[i]` is set + /// to 0. + void getLocalReprs( + std::vector> &dividends, + SmallVector &denominators, + std::vector>> &repr) const; + void getLocalReprs( std::vector>> &repr) const; + void getLocalReprs(std::vector> &dividends, + SmallVector &denominators) const; // Clones this object. std::unique_ptr clone() const; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1424,12 +1424,17 @@ } /// Check if the pos^th identifier can be expressed as a floordiv of an affine -/// function of other identifiers (where the divisor is a positive constant), +/// function of other identifiers (where the divisor is a positive constant). /// `foundRepr` contains a boolean for each identifier indicating if the /// explicit representation for that identifier has already been computed. 
+/// Returns the upper and lower bound inequalities using which the floordiv can +/// be computed. If the representation could be computed, `dividend` and +/// `divisor` are set. If the representation could not be computed, +/// `llvm::None` is returned. static Optional> computeSingleVarRepr(const FlatAffineConstraints &cst, - const SmallVector &foundRepr, unsigned pos) { + const SmallVector &foundRepr, unsigned pos, + SmallVector &dividend, unsigned &divisor) { assert(pos < cst.getNumIds() && "invalid position"); assert(foundRepr.size() == cst.getNumIds() && "Size of foundRepr does not match total number of variables"); @@ -1440,9 +1445,7 @@ for (unsigned ubPos : ubIndices) { for (unsigned lbPos : lbIndices) { // Attempt to get divison representation from ubPos, lbPos. - SmallVector expr; - unsigned divisor; - if (failed(getDivRepr(cst, pos, ubPos, lbPos, expr, divisor))) + if (failed(getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor))) continue; // Check if the inequalities depend on a variable for which @@ -1452,7 +1455,7 @@ for (c = 0, f = cst.getNumIds(); c < f; ++c) { if (c == pos) continue; - if (!foundRepr[c] && expr[c] != 0) + if (!foundRepr[c] && dividend[c] != 0) break; } @@ -1470,14 +1473,29 @@ return llvm::None; } -/// Find pairs of inequalities identified by their position indices, using -/// which an explicit representation for each local variable can be computed -/// The pairs are stored as indices of upperbound, lowerbound -/// inequalities. If no such pair can be found, it is stored as llvm::None. 
-void FlatAffineConstraints::getLocalReprLbUbPairs( +void FlatAffineConstraints::getLocalReprs( std::vector>> &repr) const { - assert(repr.size() == getNumLocalIds() && - "Size of repr does not match number of local variables"); + std::vector> dividends(getNumLocalIds()); + SmallVector denominators(getNumLocalIds()); + getLocalReprs(dividends, denominators, repr); +} + +void FlatAffineConstraints::getLocalReprs( + std::vector> &dividends, + SmallVector &denominators) const { + std::vector>> repr( + getNumLocalIds()); + getLocalReprs(dividends, denominators, repr); +} + +void FlatAffineConstraints::getLocalReprs( + std::vector> &dividends, + SmallVector &denominators, + std::vector>> &repr) const { + + repr.resize(getNumLocalIds()); + dividends.resize(getNumLocalIds()); + denominators.resize(getNumLocalIds()); SmallVector foundRepr(getNumIds(), false); for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) @@ -1491,7 +1509,8 @@ changed = false; for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { if (!foundRepr[i + divOffset]) { - if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i)) { + if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i, + dividends[i], denominators[i])) { foundRepr[i + divOffset] = true; repr[i] = res; changed = true; @@ -1499,6 +1518,12 @@ } } } while (changed); + + // Set 0 denominator for identifiers for which no division representation + // could be found. + for (unsigned i = 0, e = repr.size(); i < e; ++i) + if (!repr[i].hasValue()) + denominators[i] = 0; } /// Tightens inequalities given that we are dealing with integer spaces. This is @@ -1774,36 +1799,19 @@ if (exprs[i]) foundRepr[i] = true; - auto ulPair = computeSingleVarRepr(cst, foundRepr, pos); + SmallVector dividend; + unsigned divisor; + auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); // No upper-lower bound pair found for this var. 
if (!ulPair) return false; - unsigned ubPos = ulPair->first; - - // Upper bound is of the form: - // -divisor * id + expr >= 0 - // where `id` is equivalent to `expr floordiv divisor`. - // - // Since the division cannot be dependent on itself, the coefficient of - // of `id` in `expr` is zero. The coefficient of `id` in the upperbound - // is -divisor. - int64_t divisor = -cst.atIneq(ubPos, pos); - int64_t constantTerm = cst.atIneq(ubPos, cst.getNumCols() - 1); - // Construct the dividend expression. - auto dividendExpr = getAffineConstantExpr(constantTerm, context); - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (c == pos) - continue; - int64_t ubVal = cst.atIneq(ubPos, c); - if (ubVal == 0) - continue; - // computeSingleVarRepr guarantees that expr is known here. - dividendExpr = dividendExpr + ubVal * exprs[c]; - } + auto dividendExpr = getAffineConstantExpr(dividend.back(), context); + for (unsigned c = 0, f = cst.getNumIds(); c < f; c++) + if (dividend[c] != 0) + dividendExpr = dividendExpr + dividend[c] * exprs[c]; // Successfully detected the floordiv. exprs[pos] = dividendExpr.floorDiv(divisor); diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp --- a/mlir/lib/Analysis/PresburgerSet.cpp +++ b/mlir/lib/Analysis/PresburgerSet.cpp @@ -196,7 +196,7 @@ // the local variables of sI. std::vector>> repr( sI.getNumLocalIds()); - sI.getLocalReprLbUbPairs(repr); + sI.getLocalReprs(repr); // Add sI's locals to b, after b's locals. Also add b's locals to sI, before // sI's locals. diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -648,7 +648,7 @@ std::vector>> res( fac.getNumLocalIds(), llvm::None); - fac.getLocalReprLbUbPairs(res); + fac.getLocalReprs(res); // Check if all expected divisions are computed. 
for (unsigned i = 0, e = fac.getNumLocalIds(); i < e; ++i)