diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -160,6 +160,13 @@ /// otherwise. bool containsPoint(ArrayRef point) const; + /// Find pairs of indices of inequalities, using which an explicit + /// representation for each local variable can be computed. + /// Pairs are returned as indices of upperbound, lowerbound inequalities. + /// If no such pair is found, pair is marked as None. + std::vector>> + computeLocalRepr() const; + // Clones this object. std::unique_ptr clone() const; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1250,6 +1250,109 @@ return true; } +/// Check if the pos^th identifier can be expressed as a floordiv of an affine +/// function of other identifiers (where the divisor is a positive constant), +/// foundRepr contains a boolean for each identifier indicating if the +/// explicit representation for that identifier has already been computed. +static Optional> +computeSingleVarRepr(const FlatAffineConstraints &cst, + std::vector &foundRepr, unsigned pos) { + assert(pos < cst.getNumIds() && "invalid position"); + + SmallVector lbIndices, ubIndices; + cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices); + + // Check if any lower bound, upper bound pair is of the form: + // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' + // divisor * id <= expr <-- Upper bound for 'id' + // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). 
+ // + // For example, if -32*k + 16*i + j >= 0 + // 32*k - 16*i - j + 31 >= 0 <=> + // k = ( 16*i + j ) floordiv 32 + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Due to the form of the inequalities, sum of constants of the + // inequalities is (divisor - 1) + int64_t constantSum = cst.atIneq(lbPos, cst.getNumCols() - 1) + + cst.atIneq(ubPos, cst.getNumCols() - 1); + + // Check if (sum of constants > 0). Indirectly checking (divisor > 1). + if (constantSum <= 0) + continue; + + // Check if coeff of variable is equal to divisor + if (constantSum + 1 != cst.atIneq(lbPos, pos)) + continue; + + // Check if constraints are opposite of each other. + unsigned c = 0, f = 0; + for (c = 0, f = cst.getNumIds(); c < f; ++c) { + if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c)) + break; + } + if (c < f) + continue; + + // Check if the inequalities depend on a variable for which + // an explicit representation has not been found yet. + // Exit to avoid circular dependencies between divisions + for (c = 0, f = cst.getNumIds(); c < f; ++c) { + if (c == pos) + continue; + if (!foundRepr[c] && cst.atIneq(lbPos, c) != 0) + break; + } + + // Expression can't be constructed as it depends on a yet unknown + // identifier. + // TODO: Visit/compute the identifiers in an order so that this doesn't + // happen. More complex but much more efficient. + if (c < f) + continue; + + foundRepr[pos] = true; + return llvm::Optional>({ubPos, lbPos}); + } + } + + return llvm::None; +} + +/// Find pairs of indices of inequalities, using which an explicit +/// representation for each local variable can be computed. +/// Pairs are returned as indices of upperbound, lowerbound inequalities. +/// If no such pair is found, pair is marked as None. 
+std::vector>> +FlatAffineConstraints::computeLocalRepr() const { + + std::vector>> repr( + getNumLocalIds(), llvm::None); + std::vector foundRepr(getNumCols() - 1, false); + + for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) + foundRepr[i] = true; + + unsigned divOffset = getNumDimAndSymbolIds(); + bool changed; + do { + // Each time changed is true, at end of this iteration, one or more local + // vars have been detected as floor divs. + changed = false; + for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { + if (!foundRepr[i + divOffset]) { + auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i); + if (!res) + continue; + repr[i] = res; + changed = true; + } + } + } while (changed); + + return repr; +} + /// Tightens inequalities given that we are dealing with integer spaces. This is /// analogous to the GCD test but applied to inequalities. The constant term can /// be reduced to the preceding multiple of the GCD of the coefficients, i.e., @@ -1516,70 +1619,46 @@ SmallVectorImpl &exprs) { assert(pos < cst.getNumIds() && "invalid position"); - SmallVector lbIndices, ubIndices; - cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices); + // Get upper-lower bound pair for this variable + std::vector foundRepr(cst.getNumIds(), false); + for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i) { + if (exprs[i]) + foundRepr[i] = true; + } + auto ulPair = computeSingleVarRepr(cst, foundRepr, pos); - // Check if any lower bound, upper bound pair is of the form: - // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' - // divisor * id <= expr <-- Upper bound for 'id' - // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). 
+ // No upper-lower bound pair found for this var + if (!ulPair) + return false; + + unsigned ubPos = ulPair->first, lbPos = ulPair->second; + + // Lower bound is of the form: + // divisor * id >= expr - (divisor - 1) // - // For example, if -32*k + 16*i + j >= 0 - // 32*k - 16*i - j + 31 >= 0 <=> - // k = ( 16*i + j ) floordiv 32 - unsigned seenDividends = 0; - for (auto ubPos : ubIndices) { - for (auto lbPos : lbIndices) { - // Check if the lower bound's constant term is divisor - 1. The - // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's - // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'. - int64_t divisor = cst.atIneq(lbPos, pos); - int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1); - if (lbConstTerm != divisor - 1) - continue; - // Check if upper bound's constant term is 0. - if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) - continue; - // For the remaining part, check if the lower bound expr's coeff's are - // negations of corresponding upper bound ones'. - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) - break; - if (c != pos && cst.atIneq(lbPos, c) != 0) - seenDividends++; - } - // Lb coeff's aren't negative of ub coeff's (for the non constant term - // part). - if (c < f) - continue; - if (seenDividends >= 1) { - // Construct the dividend expression. - auto dividendExpr = getAffineConstantExpr(0, context); - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (c == pos) - continue; - int64_t ubVal = cst.atIneq(ubPos, c); - if (ubVal == 0) - continue; - if (!exprs[c]) - break; - dividendExpr = dividendExpr + ubVal * exprs[c]; - } - // Expression can't be constructed as it depends on a yet unknown - // identifier. - // TODO: Visit/compute the identifiers in an order so that this doesn't - // happen. More complex but much more efficient. - if (c < f) - continue; - // Successfully detected the floordiv. 
- exprs[pos] = dividendExpr.floorDiv(divisor); - return true; - } - } + // The divisor is the coefficient of 'id' in the lower bound inequality. + // The lower bound's constant is (divisor - 1) - constantTermOf(expr), so: + // constantTerm = (divisor - 1) - constantTermOf(lower bound) + int64_t divisor = cst.atIneq(lbPos, pos); + int64_t constantTerm = + (divisor - 1) - cst.atIneq(lbPos, cst.getNumCols() - 1); + + // Construct the dividend expression. + auto dividendExpr = getAffineConstantExpr(constantTerm, context); + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (c == pos) + continue; + int64_t ubVal = cst.atIneq(ubPos, c); + if (ubVal == 0) + continue; + // computeSingleVarRepr guarantees that expr is known here + dividendExpr = dividendExpr + ubVal * exprs[c]; } - return false; + + // Successfully detected the floordiv. + exprs[pos] = dividendExpr.floorDiv(divisor); + return true; } // Fills an inequality row with the value 'val'. diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -587,4 +587,75 @@ EXPECT_EQ(fac.atIneq(0, 1), 0); } +void checkDivisionRepresentation(FlatAffineConstraints &fac, bool all = true) { + // Check if representation can be computed for all local variables + auto res = fac.computeLocalRepr(); + + // If all is set, check if all divisions are computed + if (all) { + for (auto &r : res) + EXPECT_NE(r, llvm::None); + } + + unsigned divOffset = fac.getNumDimAndSymbolIds(); + for (unsigned i = 0, e = fac.getNumLocalIds(); i < e; ++i) { + if (!res[i]) + continue; + + // Check if bounds are of the form: + // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' + // divisor * id <= expr <-- Upper bound for 'id' + unsigned ubPos = res[i]->first, lbPos = res[i]->second; + int64_t constantSum = fac.atIneq(lbPos, fac.getNumCols() - 1) + + fac.atIneq(ubPos, fac.getNumCols() - 1); + + // Check if (sum of 
constants > 0). Indirectly checking (divisor > 1). + EXPECT_TRUE(constantSum > 0); + + // Check if coeff of variable is equal to divisor + EXPECT_EQ(constantSum + 1, fac.atIneq(lbPos, i + divOffset)); + + // Check if constraints are opposite of each other. + for (unsigned c = 0, f = fac.getNumIds(); c < f; ++c) + EXPECT_EQ(fac.atIneq(ubPos, c), -fac.atIneq(lbPos, c)); + } +} + +TEST(FlatAffineConstraintsTest, computeLocalReprSimple) { + FlatAffineConstraints fac = makeFACFromConstraints(1, {}, {}); + + fac.addLocalFloorDiv({1, 4}, 10); + fac.addLocalFloorDiv({1, 10, 100}, 10); + + checkDivisionRepresentation(fac); +} + +TEST(FlatAffineConstraintsTest, computeLocalReprConstantFloorDiv) { + FlatAffineConstraints fac = makeFACFromConstraints(4, {}, {}); + + fac.addInequality({1, 0, 3, 1, 2}); + fac.addInequality({1, 2, -8, 1, 10}); + fac.addEquality({1, 2, -4, 1, 10}); + fac.addLocalFloorDiv({0, 0, 0, 0, 10}, 30); + + checkDivisionRepresentation(fac); +} + +TEST(FlatAffineConstraintsTest, computeLocalReprRecursive) { + FlatAffineConstraints fac = makeFACFromConstraints(4, {}, {}); + fac.addInequality({1, 0, 3, 1, 2}); + fac.addInequality({1, 2, -8, 1, 10}); + fac.addInequality({1, 2, -8, 1, 10}); + fac.addEquality({1, 2, -4, 1, 10}); + + fac.addLocalFloorDiv({0, -2, 7, 2, 10}, 3); + fac.addLocalFloorDiv({3, 0, 9, 2, 2, 10}, 5); + fac.addLocalFloorDiv({0, 1, -123, 2, 0, -4, 10}, 3); + + fac.addInequality({1, 2, -2, 1, -5, 0, 6, 100}); + fac.addInequality({1, 2, -8, 1, 3, 7, 0, -9}); + + checkDivisionRepresentation(fac); +} + } // namespace mlir