diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1215,15 +1215,28 @@
 /// -divisor * id + expr >= 0 <-- Upper bound for 'id'
 ///
 /// For example:
-/// 32*k >= 16*i + j - 31 <-- Lower bound for 'k'
-/// 32*k <= 16*i + j <-- Upper bound for 'k'
-/// expr = 16*i + j, divisor = 32
-/// k = ( 16*i + j ) floordiv 32
+///   32*k >= 16*i + j - 31                     <-- Lower bound for 'k'
+///   32*k <= 16*i + j                          <-- Upper bound for 'k'
+///   expr = 16*i + j, divisor = 32
+///   k = ( 16*i + j ) floordiv 32
 ///
-/// 4q >= i + j - 2 <-- Lower bound for 'q'
-/// 4q <= i + j + 1 <-- Upper bound for 'q'
-/// expr = i + j + 1, divisor = 4
-/// q = (i + j + 1) floordiv 4
+///   4q >= i + j - 2                           <-- Lower bound for 'q'
+///   4q <= i + j + 1                           <-- Upper bound for 'q'
+///   expr = i + j + 1, divisor = 4
+///   q = (i + j + 1) floordiv 4
+///
+/// This function also supports detecting divisions from bounds that are
+/// strictly tighter than the division bounds described above, since tighter
+/// bounds imply the division bounds. For example:
+///   4q - i - j + 2 >= 0                       <-- Lower bound for 'q'
+///   -4q + i + j >= 0                          <-- Tight upper bound for 'q'
+///
+/// To extract floor divisions with tighter bounds, we assume that the
+/// constraints are of the form:
+///   c <= expr - divisor * id <= divisor - 1, where 0 <= c <= divisor - 1
+/// Rearranging, we have:
+///   divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
+///   -divisor * id + expr - c >= 0             <-- Upper bound for 'id'
 ///
 /// If successful, `expr` is set to dividend of the division and `divisor` is
 /// set to the denominator of the division.
@@ -1238,21 +1251,12 @@
   assert(lbIneq <= cst.getNumInequalities() &&
          "Invalid upper bound inequality position");
 
-  // Due to the form of the inequalities, sum of constants of the
-  // inequalities is (divisor - 1).
-  int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
-                        cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1;
+  // Extract the candidate divisor from the lower bound. `divisor` itself is
+  // only written once the pair is known to represent a division.
+  int64_t candidateDivisor = cst.atIneq(lbIneq, pos);
 
-  // Divisor should be positive.
-  if (denominator <= 0)
-    return failure();
-
-  // Check if coeff of variable is equal to divisor.
-  if (denominator != cst.atIneq(lbIneq, pos))
-    return failure();
-
-  // Check if constraints are opposite of each other. Constant term
-  // is not required to be opposite and is not checked.
+  // First, check if the constraints are opposite of each other except the
+  // constant term.
   unsigned i = 0, e = 0;
   for (i = 0, e = cst.getNumIds(); i < e; ++i)
     if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i))
@@ -1261,15 +1265,35 @@
   if (i < e)
     return failure();
 
-  // Set expr with dividend of the division.
-  SmallVector<int64_t, 8> dividend(cst.getNumCols());
-  for (i = 0, e = cst.getNumCols(); i < e; ++i)
+  // Then, check if the constant term is of the proper form.
+  // Due to the form of the upper/lower bound inequalities, the sum of their
+  // constants is `divisor - 1 - c`. From this, we can extract c:
+  int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
+                        cst.atIneq(ubIneq, cst.getNumCols() - 1);
+  int64_t c = candidateDivisor - 1 - constantSum;
+
+  // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. Since the
+  // arithmetic is done on signed values, this also implicitly checks that
+  // `divisor` is positive.
+  if (!(c >= 0 && c <= candidateDivisor - 1))
+    return failure();
+
+  // The inequality pair can be used to extract the division.
+  // Set `expr` to the dividend of the division except the constant term, which
+  // is set below. `assign` (unlike `resize`) also zeroes any stale entries,
+  // including the coefficient at `pos`, which the loop below does not write.
+  expr.assign(cst.getNumCols(), 0);
+  for (i = 0, e = cst.getNumIds(); i < e; ++i)
     if (i != pos)
-      dividend[i] = cst.atIneq(ubIneq, i);
-  expr = dividend;
+      expr[i] = cst.atIneq(ubIneq, i);
 
-  // Set divisor.
-  divisor = denominator;
+  // From the upper bound inequality's form, its constant term is equal to the
+  // constant term of `expr`, minus `c`. From this,
+  // constant term of `expr` = constant term of upper bound + `c`.
+  expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
+
+  // Only write `divisor` now that the pair is known to form a division.
+  divisor = candidateDivisor;
 
   return success();
 }
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -623,11 +623,17 @@
 
   fac.getLocalReprs(dividends, denominators);
 
-  // Check that the `dividends` and `expectedDividends` match.
-  EXPECT_TRUE(expectedDividends == dividends);
-
   // Check that the `denominators` and `expectedDenominators` match.
   EXPECT_TRUE(expectedDenominators == denominators);
+
+  // Check that the `dividends` and `expectedDividends` match. If the
+  // denominator for a division is zero, we ignore its dividend. The sizes are
+  // asserted first so that the indexed loop below cannot go out of bounds.
+  ASSERT_EQ(dividends.size(), expectedDividends.size());
+  ASSERT_EQ(denominators.size(), dividends.size());
+  for (unsigned i = 0, e = dividends.size(); i < e; ++i)
+    if (denominators[i] != 0)
+      EXPECT_TRUE(expectedDividends[i] == dividends[i]);
 }
 
 TEST(FlatAffineConstraintsTest, computeLocalReprSimple) {
@@ -687,6 +693,57 @@
   checkDivisionRepresentation(fac, divisions, denoms);
 }
 
+TEST(FlatAffineConstraintsTest, computeLocalReprTightUpperBound) {
+  MLIRContext context;
+
+  {
+    FlatAffineConstraints fac = parseFAC("(i) : (i mod 3 - 1 >= 0)", &context);
+
+    // The set formed by the fac is:
+    //        3q - i + 2 >= 0             <-- Division lower bound
+    //       -3q + i - 1 >= 0
+    //       -3q + i     >= 0             <-- Division upper bound
+    // We remove redundant constraints to get the set:
+    //        3q - i + 2 >= 0             <-- Division lower bound
+    //       -3q + i - 1 >= 0             <-- Tighter division upper bound
+    // thus, making the upper bound tighter.
+    fac.removeRedundantConstraints();
+
+    std::vector<SmallVector<int64_t, 8>> divisions = {{1, 0, 0}};
+    SmallVector<unsigned, 8> denoms = {3};
+
+    // Check if the divisions can be computed even with a tighter upper bound.
+    checkDivisionRepresentation(fac, divisions, denoms);
+  }
+
+  {
+    FlatAffineConstraints fac = parseFAC(
+        "(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)", &context);
+    // Convert `q` to a local variable.
+    fac.convertDimToLocal(2, 3);
+
+    std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
+    SmallVector<unsigned, 8> denoms = {4};
+
+    // Check if the divisions can be computed even with a tighter upper bound.
+    checkDivisionRepresentation(fac, divisions, denoms);
+  }
+}
+
+TEST(FlatAffineConstraintsTest, computeLocalReprNoRepr) {
+  MLIRContext context;
+  FlatAffineConstraints fac =
+      parseFAC("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)", &context);
+  // Convert q to a local variable.
+  fac.convertDimToLocal(1, 2);
+
+  std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
+  SmallVector<unsigned, 8> denoms = {0};
+
+  // Check that no division is computed.
+  checkDivisionRepresentation(fac, divisions, denoms);
+}
+
 TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
   // (x) : (exists y: 2x + y = 1 and y = 2).
   FlatAffineConstraints fac(1, 0, 1);