diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -443,16 +443,17 @@
   /// identifier. Returns None if it's not a constant. This method employs
   /// trivial (low complexity / cost) checks and detection. Symbolic identifiers
   /// are treated specially, i.e., it looks for constant differences between
-  /// affine expressions involving only the symbolic identifiers. See comments
-  /// at function definition for examples. 'lb' and 'lbDivisor', if provided,
-  /// are used to express the lower bound associated with the constant
-  /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg.,
-  /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three
-  /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32.
+  /// affine expressions involving only the symbolic identifiers. If non-null,
+  /// `lb` and `ub` (along with `boundFloorDivisor`, required whenever `lb` is
+  /// given) are set to the lower and upper bound associated with the constant
+  /// difference: `lb` and `ub` hold the coefficients and `boundFloorDivisor`
+  /// their common divisor. Ex: if the lower bound is [(s0 + s2 - 1) floordiv
+  /// 32] for a system with three symbolic identifiers, *lb = [1, 0, 1, -1]
+  /// and *boundFloorDivisor = 32. See the function definition for examples.
   Optional<int64_t>
   getConstantBoundOnDimSize(unsigned pos,
                             SmallVectorImpl<int64_t> *lb = nullptr,
-                            int64_t *lbFloorDivisor = nullptr,
+                            int64_t *boundFloorDivisor = nullptr,
                             SmallVectorImpl<int64_t> *ub = nullptr) const;
 
   /// Returns the constant lower bound for the pos^th identifier if there is
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1201,17 +1201,30 @@
   return false;
 }
 
-/// Gather all lower and upper bounds of the identifier at `pos`.
+/// Gather all lower and upper bounds of the identifier at `pos`, skipping any
+/// bound that involves another identifier in [offset, offset + num).
 static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
                                          unsigned pos,
                                          SmallVectorImpl<unsigned> *lbIndices,
-                                         SmallVectorImpl<unsigned> *ubIndices) {
+                                         SmallVectorImpl<unsigned> *ubIndices,
+                                         unsigned offset = 0,
+                                         unsigned num = 0) {
   assert(pos < cst.getNumIds() && "invalid position");
 
   // Gather all lower bounds and upper bounds of the variable. Since the
   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+    // The bounds are to be independent of [offset, offset + num) columns.
+    unsigned c, f;
+    for (c = offset, f = offset + num; c < f; ++c) {
+      if (c == pos)
+        continue;
+      if (cst.atIneq(r, c) != 0)
+        break;
+    }
+    if (c < f)
+      continue;
     if (cst.atIneq(r, pos) >= 1) {
       // Lower bound.
       lbIndices->push_back(r);
@@ -1866,7 +1879,8 @@
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
-/// symbols is also treated like a constant.
+/// symbols is also treated like a constant. Returns -1 if such an equality
+/// could not be found.
 static int findEqualityToConstant(const FlatAffineConstraints &cst,
                                   unsigned pos, bool symbolic = false) {
   assert(pos < cst.getNumIds() && "invalid position");
@@ -1937,19 +1951,15 @@
-// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
+// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, boundFloorDivisor = 8 (since lb =
 // ceil(s0 - 7 / 8) = floor(s0 / 8)).
Optional FlatAffineConstraints::getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb, int64_t *lbFloorDivisor, + unsigned pos, SmallVectorImpl *lb, int64_t *boundFloorDivisor, SmallVectorImpl *ub) const { assert(pos < getNumDimIds() && "Invalid identifier position"); assert(getNumLocalIds() == 0); - // TODO(bondhugula): eliminate all remaining dimensional identifiers (other - // than the one at 'pos' to make this more powerful. Not needed for - // hyper-rectangular spaces. - // Find an equality for 'pos'^th identifier that equates it to some function // of the symbolic identifiers (+ constant). - int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); - if (eqRow != -1) { + int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true); + if (eqPos != -1) { // This identifier can only take a single value. if (lb) { // Set lb to that symbolic value. @@ -1957,18 +1967,18 @@ if (ub) ub->resize(getNumSymbolIds() + 1); for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { - int64_t v = atEq(eqRow, pos); + int64_t v = atEq(eqPos, pos); // atEq(eqRow, pos) is either -1 or 1. assert(v * v == 1); - (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v - : -atEq(eqRow, getNumDimIds() + c) / v; + (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v + : -atEq(eqPos, getNumDimIds() + c) / v; // Since this is an equality, ub = lb. if (ub) (*ub)[c] = (*lb)[c]; } - assert(lbFloorDivisor && + assert(boundFloorDivisor && "both lb and divisor or none should be provided"); - *lbFloorDivisor = 1; + *boundFloorDivisor = 1; } return 1; } @@ -1990,25 +2000,9 @@ // the bounds can only involve symbolic (and local) identifiers. Since the // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 
-  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
-    unsigned c, f;
-    for (c = 0, f = getNumDimIds(); c < f; c++) {
-      if (c != pos && atIneq(r, c) != 0)
-        break;
-    }
-    if (c < getNumDimIds())
-      // Not a pure symbolic bound.
-      continue;
-    if (atIneq(r, pos) >= 1)
-      // Lower bound.
-      lbIndices.push_back(r);
-    else if (atIneq(r, pos) <= -1)
-      // Upper bound.
-      ubIndices.push_back(r);
-  }
-
-  // TODO(bondhugula): eliminate other dimensional identifiers to make this more
-  // powerful. Not needed for hyper-rectangular iteration spaces.
+  getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices,
+                               /*offset=*/0,
+                               /*num=*/getNumDimIds());
 
   Optional<int64_t> minDiff = None;
   unsigned minLbPosition, minUbPosition;
@@ -2046,8 +2040,8 @@
     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
-    *lbFloorDivisor = atIneq(minLbPosition, pos);
-    assert(*lbFloorDivisor == -atIneq(minUbPosition, pos));
+    *boundFloorDivisor = atIneq(minLbPosition, pos);
+    assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
     }