diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1352,6 +1352,77 @@
   return true;
 }
 
+/// Check if the pos^th identifier can be represented as a division using upper
+/// bound inequality at position `ubIneq` and lower bound inequality at position
+/// `lbIneq`.
+///
+/// Let `id` be the pos^th identifier, then `id` is equivalent to
+/// `expr floordiv divisor` if there are constraints of the form:
+///      0 <= expr - divisor * id <= divisor - 1
+/// Rearranging, we have:
+///       divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
+///      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
+///
+/// For example:
+///       32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
+///       32*k  <= 16*i + j                     <-- Upper bound for 'k'
+///       expr = 16*i + j, divisor = 32
+///       k = ( 16*i + j ) floordiv 32
+///
+///       4q >= i + j - 2                       <-- Lower bound for 'q'
+///       4q <= i + j + 1                       <-- Upper bound for 'q'
+///       expr = i + j + 1, divisor = 4
+///       q = (i + j + 1) floordiv 4
+///
+/// On success, `expr` holds the dividend's coefficients (one per column of
+/// `cst`, with a zero at `pos`) and `divisor` holds the positive divisor.
+static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos,
+                                unsigned ubIneq, unsigned lbIneq,
+                                SmallVector<int64_t, 8> &expr,
+                                unsigned &divisor) {
+
+  assert(pos < cst.getNumIds() && "Invalid identifier position");
+  assert(ubIneq < cst.getNumInequalities() &&
+         "Invalid upper bound inequality position");
+  assert(lbIneq < cst.getNumInequalities() &&
+         "Invalid lower bound inequality position");
+
+  // Due to the form of the inequalities, sum of constants of the
+  // inequalities is (divisor - 1).
+  int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
+                        cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1;
+
+  // Divisor should be positive.
+  if (denominator <= 0)
+    return failure();
+
+  // The coefficient of `id` in the lower bound must equal the divisor.
+  if (denominator != cst.atIneq(lbIneq, pos))
+    return failure();
+
+  // Check if constraints are opposite of each other. Constant term
+  // is not required to be opposite and is not checked.
+  unsigned i = 0, e = 0;
+  for (i = 0, e = cst.getNumIds(); i < e; ++i)
+    if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i))
+      break;
+
+  if (i < e)
+    return failure();
+
+  // Set expr with dividend of the division.
+  SmallVector<int64_t, 8> dividend(cst.getNumCols());
+  for (i = 0, e = cst.getNumCols(); i < e; ++i)
+    if (i != pos)
+      dividend[i] = cst.atIneq(ubIneq, i);
+  expr = dividend;
+
+  // Set divisor.
+  divisor = denominator;
+
+  return success();
+}
+
 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
 /// function of other identifiers (where the divisor is a positive constant),
 /// `foundRepr` contains a boolean for each identifier indicating if the
@@ -1366,55 +1437,22 @@
   SmallVector<unsigned, 4> lbIndices, ubIndices;
   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
 
-  // `id` is equivalent to `expr floordiv divisor` if there
-  // are constraints of the form:
-  //      0 <= expr - divisor * id <= divisor - 1
-  // Rearranging, we have:
-  //       divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
-  //      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
-  //
-  // For example:
-  //       32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
-  //       32*k  <= 16*i + j                     <-- Upper bound for 'k'
-  //       expr = 16*i + j, divisor = 32
-  //       k = ( 16*i + j ) floordiv 32
-  //
-  //       4q >= i + j - 2                       <-- Lower bound for 'q'
-  //       4q <= i + j + 1                       <-- Upper bound for 'q'
-  //       expr = i + j + 1, divisor = 4
-  //       q = (i + j + 1) floordiv 4
   for (unsigned ubPos : ubIndices) {
     for (unsigned lbPos : lbIndices) {
-      // Due to the form of the inequalities, sum of constants of the
-      // inequalities is (divisor - 1).
-      int64_t divisor = cst.atIneq(lbPos, cst.getNumCols() - 1) +
-                        cst.atIneq(ubPos, cst.getNumCols() - 1) + 1;
-
-      // Divisor should be positive.
-      if (divisor <= 0)
-        continue;
-
-      // Check if coeff of variable is equal to divisor.
-      if (divisor != cst.atIneq(lbPos, pos))
-        continue;
-
-      // Check if constraints are opposite of each other. Constant term
-      // is not required to be opposite and is not checked.
-      unsigned c = 0, f = 0;
-      for (c = 0, f = cst.getNumIds(); c < f; ++c)
-        if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c))
-          break;
-
-      if (c < f)
+      // Attempt to get division representation from ubPos, lbPos.
+      SmallVector<int64_t, 8> expr;
+      unsigned divisor;
+      if (failed(getDivRepr(cst, pos, ubPos, lbPos, expr, divisor)))
         continue;
 
       // Check if the inequalities depend on a variable for which
       // an explicit representation has not been found yet.
       // Exit to avoid circular dependencies between divisions.
+      unsigned c, f;
       for (c = 0, f = cst.getNumIds(); c < f; ++c) {
         if (c == pos)
           continue;
-        if (!foundRepr[c] && cst.atIneq(lbPos, c) != 0)
+        if (!foundRepr[c] && expr[c] != 0)
           break;
       }