diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1374,6 +1374,21 @@ /// expr = i + j + 1, divisor = 4 /// q = (i + j + 1) floordiv 4 /// +/// Sometimes, due to additional constraints, the original upper bound of the +/// floor division may be removed. For example: +/// -3q + i - 1 >= 0 <-- `ineq1` +/// 3q - i - 2 >= 0 <-- Lower bound for 'q' +/// -3q + i >= 0 <-- Upper bound for 'q' +/// +/// Here, `ineq1` makes the upper bound for 'q' redundant, and thus the upper +/// bound may be removed. To extract floor divisions in cases where `ineq1` is +/// given as `ubIneq`, we relax the condition on the upper bound of the local +/// variable to: +/// -divisor * id + expr - c >= 0, where 0 <= c <= divisor - 1 +/// +/// So the final condition we need to check is: +/// c <= expr - divisor * id <= divisor - 1, where 0 <= c <= divisor - 1 +/// /// If successful, `expr` is set to dividend of the division and `divisor` is /// set to the denominator of the division. static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos, @@ -1387,17 +1402,32 @@ assert(lbIneq <= cst.getNumInequalities() && "Invalid upper bound inequality position"); - // Due to the form of the inequalities, sum of constants of the - // inequalities is (divisor - 1). - int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) + - cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1; + // Due to the form of the upper/lower bound inequalities, the sum of their + // constants is `divisor - 1 - c` when using the relaxed upper bound. + // We check the condition `0 <= c <= divisor - 1` by checking two conditions + // on the sum of constants: + // + // 1. Sum of constants >= 0 + // divisor - 1 - c >= 0 + // divisor - 1 >= c + // + // 2.
Sum of constants <= divisor - 1 + // divisor - 1 - c <= divisor - 1 + // c >= 0 + // + // Together, these two conditions enforce the required bounds on c: + // 0 <= c <= divisor - 1 + // + // as required by the relaxed upper bound inequality form. + int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) + + cst.atIneq(ubIneq, cst.getNumCols() - 1); - // Divisor should be positive. - if (denominator <= 0) + // constantSum must be non-negative (i.e. c <= divisor - 1). + if (constantSum < 0) return failure(); - // Check if coeff of variable is equal to divisor. - if (denominator != cst.atIneq(lbIneq, pos)) + // constantSum <= divisor - 1 (i.e. c >= 0); this also forces divisor > 0. + if (constantSum >= cst.atIneq(lbIneq, pos)) return failure(); // Check if constraints are opposite of each other. Constant term @@ -1410,15 +1440,20 @@ if (i < e) return failure(); - // Set expr with dividend of the division. - SmallVector dividend(cst.getNumCols()); - for (i = 0, e = cst.getNumCols(); i < e; ++i) + // Extract divisor from the lower bound. + divisor = cst.atIneq(lbIneq, pos); + + // Set expr with the dividend of the division, except the constant term. + SmallVector dividend(cst.getNumCols(), 0); + for (i = 0, e = cst.getNumIds(); i < e; ++i) if (i != pos) dividend[i] = cst.atIneq(ubIneq, i); - expr = dividend; - // Set divisor. - divisor = denominator; + // Set constant term of dividend. From the lower bound form: + // constant term of dividend = (divisor - 1) - constant term of lower bound.
+ dividend.back() = (divisor - 1) - cst.atIneq(lbIneq, cst.getNumCols() - 1); + + expr = dividend; return success(); } diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -677,10 +677,10 @@ continue; // Check if the bounds are of the form: - // 0 <= expr - divisor * id <= divisor - 1 + // c <= expr - divisor * id <= divisor - 1, where 0 <= c <= divisor - 1 // Rearranging, we have: // divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id' - // -divisor * id + expr >= 0 <-- Upper bound for 'id' + // -divisor * id + expr - c >= 0 <-- Upper bound for 'id' // where `id = expr floordiv divisor`. unsigned ubPos = res[i]->first, lbPos = res[i]->second; const SmallVector &expr = divisions[i]; @@ -705,8 +705,11 @@ continue; EXPECT_EQ(fac.atIneq(ubPos, c), expr[c]); } + // Check if constant term of upper bound matches expected constant term. - EXPECT_EQ(fac.atIneq(ubPos, fac.getNumCols() - 1), expr.back()); + int64_t ubIneqC = expr.back() - fac.atIneq(ubPos, fac.getNumCols() - 1); + EXPECT_GE(ubIneqC, 0); + EXPECT_LE(ubIneqC, denoms[i] - 1); } } @@ -765,6 +768,21 @@ checkDivisionRepresentation(fac, divisions, denoms); } +TEST(FlatAffineConstraintsTest, computeLocalReprRedundantUpperBound) { + MLIRContext context; + FlatAffineConstraints fac = parseFAC("(d0) : (d0 mod 3 - 1 >= 0)", &context); + + // Remove redundant constraints to check if the division computation can + // handle removal of the redundant upper bound. + fac.removeRedundantConstraints(); + + std::vector> divisions = {{1, 0, 0}}; + SmallVector denoms = {3}; + + // Check if the floordiv is still computed after its upper bound is removed. + checkDivisionRepresentation(fac, divisions, denoms); +} + TEST(FlatAffineConstraintsTest, removeIdRange) { FlatAffineConstraints fac(3, 2, 1);