diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1256,36 +1256,40 @@ /// explicit representation for that identifier has already been computed. static Optional> computeSingleVarRepr(const FlatAffineConstraints &cst, - std::vector &foundRepr, unsigned pos) { + const std::vector &foundRepr, unsigned pos) { assert(pos < cst.getNumIds() && "invalid position"); + assert(foundRepr.size() == cst.getNumIds()); SmallVector lbIndices, ubIndices; cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices); - // Check if any lower bound, upper bound pair is of the form: - // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' - // divisor * id <= expr <-- Upper bound for 'id' - // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). + // `id` is equivalent to `expr floordiv divisor` if there + // are constraints of the form: + // 0 <= divisor * id - expr <= divisor - 1 + // Rearranging, we have: + // divisor * id - expr >= 0 <-- Lower bound for 'id' + // -divisor * id + expr + (divisor - 1) >= 0 <-- Upper bound for 'id' // // For example, if -32*k + 16*i + j >= 0 // 32*k - 16*i - j + 31 >= 0 <=> // k = ( 16*i + j ) floordiv 32 - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { + for (unsigned ubPos : ubIndices) { + for (unsigned lbPos : lbIndices) { // Due to the form of the inequalities, sum of constants of the // inequalities is (divisor - 1) - int64_t constantSum = cst.atIneq(lbPos, cst.getNumCols() - 1) + - cst.atIneq(ubPos, cst.getNumCols() - 1); + int64_t divisor = cst.atIneq(lbPos, cst.getNumCols() - 1) + + cst.atIneq(ubPos, cst.getNumCols() - 1) + 1; - // Check if (sum of constants > 0). Indirectly checking (divisor > 1). - if (constantSum <= 0) + // Divisor should be positive. 
+ if (divisor <= 0) continue; - // Check if coeff of variable is equal to divisor - if (constantSum + 1 != cst.atIneq(lbPos, pos)) + // Check if coeff of variable is equal to divisor. + if (divisor != cst.atIneq(lbPos, pos)) continue; - // Check if constraints are opposite of each other. + // Check if constraints are opposite of each other. Constant term + // is not required to be opposite and is not checked. unsigned c = 0, f = 0; for (c = 0, f = cst.getNumIds(); c < f; ++c) { if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c)) @@ -1296,7 +1300,7 @@ // Check if the inequalities depend on a variable for which // an explicit representation has not been found yet. - // Exit to avoid circular dependencies between divisions + // Exit to avoid circular dependencies between divisions. for (c = 0, f = cst.getNumIds(); c < f; ++c) { if (c == pos) continue; @@ -1311,7 +1315,6 @@ if (c < f) continue; - foundRepr[pos] = true; return llvm::Optional>({ubPos, lbPos}); } } @@ -1328,7 +1331,7 @@ std::vector>> repr( getNumLocalIds(), llvm::None); - std::vector foundRepr(getNumCols() - 1, false); + std::vector foundRepr(getNumIds(), false); for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) foundRepr[i] = true; @@ -1341,11 +1344,11 @@ changed = false; for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { if (!foundRepr[i + divOffset]) { - auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i); - if (!res) - continue; - repr[i] = res; - changed = false; + if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i)) { + foundRepr[i + divOffset] = true; + repr[i] = res; + changed = true; + } } } } while (changed); @@ -1621,10 +1624,10 @@ // Get upper-lower bound pair for this variable std::vector foundRepr(cst.getNumIds(), false); - for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i) { + for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i) if (exprs[i]) foundRepr[i] = true; - } + auto ulPair = computeSingleVarRepr(cst, foundRepr, pos); // No 
upper-lower bound pair found for this var @@ -1634,9 +1637,9 @@ unsigned ubPos = ulPair->first, lbPos = ulPair->second; // Lower bound is of the form: - // divisor * id >= expr - (divisor - 1) + // divisor * id >= expr // - // The divisor is the coefficent of the division + // The divisor is the coefficient of the division variable // the constant term for the division is: // constantTerm = constantTermOf(expr) - (divisor - 1) int64_t divisor = cst.atIneq(lbPos, pos); diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -593,7 +593,7 @@ // If all is set, check if all divisions are computed if (all) { - for (auto &r : res) + for (auto r : res) EXPECT_NE(r, llvm::None); } @@ -602,9 +602,9 @@ if (!res[i]) continue; - // Check if bounds are of the form: - // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' - // divisor * id <= expr <-- Upper bound for 'id' + // Check if bounds are of the form: + // divisor * id >= expr <-- Lower bound for 'id' + // divisor * id <= expr + (divisor - 1) <-- Upper bound for 'id' unsigned ubPos = res[i]->first, lbPos = res[i]->second; int64_t constantSum = fac.atIneq(lbPos, fac.getNumCols() - 1) + fac.atIneq(ubPos, fac.getNumCols() - 1); @@ -625,7 +625,11 @@ FlatAffineConstraints fac = makeFACFromConstraints(1, {}, {}); fac.addLocalFloorDiv({1, 4}, 10); - fac.addLocalFloorDiv({1, 10, 100}, 10); + fac.addLocalFloorDiv({1, 0, 100}, 10); + + // Check if floordivs can be recovered when no other inequalities exist + // and floordivs do not depend on each other. 
+ checkDivisionRepresentation(fac); } TEST(FlatAffineConstraintsTest, computeLocalReprConstantFloorDiv) { @@ -634,8 +638,11 @@ fac.addInequality({1, 0, 3, 1, 2}); fac.addInequality({1, 2, -8, 1, 10}); fac.addEquality({1, 2, -4, 1, 10}); + fac.addLocalFloorDiv({0, 0, 0, 0, 10}, 30); + fac.addLocalFloorDiv({0, 0, 0, 0, 0, 99}, 101); + // Check if floordivs with constant numerator can be recovered. checkDivisionRepresentation(fac); } @@ -653,6 +660,7 @@ fac.addInequality({1, 2, -2, 1, -5, 0, 6, 100}); fac.addInequality({1, 2, -8, 1, 3, 7, 0, -9}); + // Check if floordivs which may depend on other floordivs can be recovered. checkDivisionRepresentation(fac); }