diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -519,6 +519,12 @@ /// Normalized each constraints by the GCD of its coefficients. void normalizeConstraintsByGCD(); + /// Get division representation for each local identifier. `reprs` and + /// `denominator` must be pre-sized to getNumLocalIds(); denominator[i] is + /// set to 0 if the `i^th` local identifier has no division representation. + void getLocalIdsReprs(std::vector> &reprs, + SmallVector &denominator); + /// Removes identifiers in the column range [idStart, idLimit), and copies any /// remaining valid data into place, updates member variables, and resizes /// arrays as needed. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1352,6 +1352,77 @@ return true; } +/// Check if the pos^th identifier can be represented as a division using upper +/// bound inequality at position `ubIneq` and lower bound inequality at position +/// `lbIneq`. +/// +/// Let `id` be the pos^th identifier, then `id` is equivalent to +/// `expr floordiv divisor` if there are constraints of the form: +/// 0 <= expr - divisor * id <= divisor - 1 +/// Rearranging, we have: +/// divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id' +/// -divisor * id + expr >= 0 <-- Upper bound for 'id' +/// +/// For example: +/// 32*k >= 16*i + j - 31 <-- Lower bound for 'k' +/// 32*k <= 16*i + j <-- Upper bound for 'k' +/// expr = 16*i + j, divisor = 32 +/// k = ( 16*i + j ) floordiv 32 +/// +/// 4q >= i + j - 2 <-- Lower bound for 'q' +/// 4q <= i + j + 1 <-- Upper bound for 'q' +/// expr = i + j + 1, divisor = 4 +/// q = (i + j + 1) floordiv 4 +/// +/// If successful, `expr` is set to the dividend of the division and `divisor` +/// to its denominator; on failure both are left unchanged.
+static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos, + unsigned ubIneq, unsigned lbIneq, + SmallVector &expr, + unsigned &divisor) { + + assert(pos < cst.getNumIds() && "Invalid identifier position"); + assert(ubIneq < cst.getNumInequalities() && + "Invalid upper bound inequality position"); + assert(lbIneq < cst.getNumInequalities() && + "Invalid lower bound inequality position"); + + // Due to the form of the inequalities, sum of constants of the + // inequalities is (divisor - 1). + int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) + + cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1; + + // Divisor should be positive. + if (denominator <= 0) + return failure(); + + // Check if coeff of variable is equal to divisor. + if (denominator != cst.atIneq(lbIneq, pos)) + return failure(); + + // Check if constraints are opposite of each other. Constant term + // is not required to be opposite and is not checked. + unsigned i = 0, e = 0; + for (i = 0, e = cst.getNumIds(); i < e; ++i) + if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i)) + break; + + if (i < e) + return failure(); + + // Set expr to the dividend, i.e. the `ubIneq` row with column `pos` zeroed. + SmallVector dividend(cst.getNumCols()); + for (i = 0, e = cst.getNumCols(); i < e; ++i) + if (i != pos) + dividend[i] = cst.atIneq(ubIneq, i); + expr = dividend; + + // Set divisor.
+ divisor = denominator; + + return success(); +} + /// Check if the pos^th identifier can be expressed as a floordiv of an affine /// function of other identifiers (where the divisor is a positive constant), /// `foundRepr` contains a boolean for each identifier indicating if the @@ -1366,55 +1437,22 @@ SmallVector lbIndices, ubIndices; cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices); - // `id` is equivalent to `expr floordiv divisor` if there - // are constraints of the form: - // 0 <= expr - divisor * id <= divisor - 1 - // Rearranging, we have: - // divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id' - // -divisor * id + expr >= 0 <-- Upper bound for 'id' - // - // For example: - // 32*k >= 16*i + j - 31 <-- Lower bound for 'k' - // 32*k <= 16*i + j <-- Upper bound for 'k' - // expr = 16*i + j, divisor = 32 - // k = ( 16*i + j ) floordiv 32 - // - // 4q >= i + j - 2 <-- Lower bound for 'q' - // 4q <= i + j + 1 <-- Upper bound for 'q' - // expr = i + j + 1, divisor = 4 - // q = (i + j + 1) floordiv 4 for (unsigned ubPos : ubIndices) { for (unsigned lbPos : lbIndices) { - // Due to the form of the inequalities, sum of constants of the - // inequalities is (divisor - 1). - int64_t divisor = cst.atIneq(lbPos, cst.getNumCols() - 1) + - cst.atIneq(ubPos, cst.getNumCols() - 1) + 1; - - // Divisor should be positive. - if (divisor <= 0) - continue; - - // Check if coeff of variable is equal to divisor. - if (divisor != cst.atIneq(lbPos, pos)) - continue; - - // Check if constraints are opposite of each other. Constant term - // is not required to be opposite and is not checked. - unsigned c = 0, f = 0; - for (c = 0, f = cst.getNumIds(); c < f; ++c) - if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c)) - break; - - if (c < f) + // Attempt to get division representation from ubPos, lbPos.
+ SmallVector expr; + unsigned divisor; + if (failed(getDivRepr(cst, pos, ubPos, lbPos, expr, divisor))) continue; // Check if the inequalities depend on a variable for which // an explicit representation has not been found yet. // Exit to avoid circular dependencies between divisions. + unsigned c, f; for (c = 0, f = cst.getNumIds(); c < f; ++c) { if (c == pos) continue; - if (!foundRepr[c] && cst.atIneq(lbPos, c) != 0) + if (!foundRepr[c] && expr[c] != 0) break; } @@ -1872,6 +1910,48 @@ equalities.resizeVertically(pos); } +void FlatAffineConstraints::getLocalIdsReprs( + std::vector> &reprs, + SmallVector &denominators) { + assert(reprs.size() == getNumLocalIds() && + "Size of reprs must be equal to number of local ids"); + assert(denominators.size() == getNumLocalIds() && + "Size of denominators must be equal to number of local ids"); + + // Get upper-lower bound inequality pairs for division representation. + std::vector>> divIneqPairs( + getNumLocalIds()); + getLocalReprLbUbPairs(divIneqPairs); + + for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { + if (!divIneqPairs[i].hasValue()) { + denominators[i] = 0; + continue; + } + + std::pair divPair = divIneqPairs[i].getValue(); + LogicalResult divExtracted = + getDivRepr(*this, i + getIdKindOffset(IdKind::Local), divPair.first, + divPair.second, reprs[i], denominators[i]); + (void)divExtracted; // Only used by the assert below. + assert(succeeded(divExtracted) && + "Div should have been found since ub-lb pair exists"); + } +} + +/// Merge local identifier at `pos2` into local identifier at `pos1` in `fac`.
+static void mergeDivision(FlatAffineConstraints &fac, unsigned pos1, + unsigned pos2) { + unsigned localOffset = fac.getNumDimAndSymbolIds(); + pos1 += localOffset; + pos2 += localOffset; + for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i) + fac.atIneq(i, pos1) += fac.atIneq(i, pos2); + for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i) + fac.atEq(i, pos1) += fac.atEq(i, pos2); + fac.removeId(pos2); +} + /// Merge local ids of `this` and `other`. This is done by appending local ids /// of `other` to `this` and inserting local ids of `this` to `other` at start /// of its local ids. Number of dimension and symbol ids should match in @@ -1881,9 +1961,67 @@ "Number of dimension ids should match"); assert(getNumSymbolIds() == other.getNumSymbolIds() && "Number of symbol ids should match"); - unsigned initLocals = getNumLocalIds(); - insertLocalId(getNumLocalIds(), other.getNumLocalIds()); - other.insertLocalId(0, initLocals); + + FlatAffineConstraints &fac1 = *this; + FlatAffineConstraints &fac2 = other; + + // Get division representations (dividends, denominators) from each FAC. + std::vector> divs1(fac1.getNumLocalIds()), + divs2(fac2.getNumLocalIds()); + SmallVector denoms1(fac1.getNumLocalIds()), + denoms2(fac2.getNumLocalIds()); + fac1.getLocalIdsReprs(divs1, denoms1); + fac2.getLocalIdsReprs(divs2, denoms2); + + // Merge local ids of fac1 and fac2 without using division information, + // i.e. append local ids of `fac2` to `fac1` and insert local ids of `fac1` + // to `fac2` at start of its local ids. + unsigned initLocals = fac1.getNumLocalIds(); + insertLocalId(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + fac2.insertLocalId(0, initLocals); + + // Merge division representation extracted from fac1 and fac2.
+ divs1.insert(divs1.end(), divs2.begin(), divs2.end()); + denoms1.insert(denoms1.end(), denoms2.begin(), denoms2.end()); + + auto dependsOnExist = [&](unsigned offset, SmallVector &div) { + for (unsigned i = offset, e = div.size(); i < e; ++i) + if (div[i] != 0) + return true; + return false; + }; + + // Find duplicate divisions and merge them. + // TODO: Add division normalization to support divisions that differ by + // a constant + for (unsigned i = 0; i < divs1.size(); ++i) { + // Check if a division exists which is duplicate of division at `i`. + for (unsigned j = i + 1; j < divs1.size(); ++j) { + // Check if division representation exists for both local ids. + if (denoms1[i] == 0 || denoms1[j] == 0) + continue; + // Check if denominators match. + if (denoms1[i] != denoms1[j]) + continue; + // Check if representation is equal. NOTE(review): assumes equal sizes. + if (!std::equal(divs1[i].begin(), divs1[i].end(), divs1[j].begin())) + continue; + // If division representation contains a local variable, do not match. + // NOTE(review): the scan runs to div.size(), i.e. it also covers the + // constant column, so nonzero-constant divisions are skipped -- confirm. + // TODO: Support divisions that depend on other local ids, e.g. by + // ordering divisions so that division `i` only depends on locals < `i`. + if (dependsOnExist(fac1.getIdKindOffset(IdKind::Local), divs1[j])) + continue; + + // Merge divisions at position `j` into division at position `i`. + mergeDivision(fac1, i, j); + mergeDivision(fac2, i, j); + divs1.erase(divs1.begin() + j); + denoms1.erase(denoms1.begin() + j); + --j; + } + } } /// Removes local variables using equalities.
Each equality is checked if it diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -795,4 +795,79 @@ EXPECT_TRUE(fac3.isEmpty()); } +TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) { + { + // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). + FlatAffineConstraints fac1(1, 0, 1); + fac1.addLocalFloorDiv({1, 0, 0}, 2); + fac1.addEquality({1, 0, -3, 0}); + fac1.addInequality({1, 1, 0, 1}); + + // (x) : (exists y = [x / 2], z : x = 5y). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); + fac2.addEquality({1, -5, 0}); + fac2.appendLocalId(); + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 1 division matched + 2 unmatched local variables. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } + + { + // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 5); + fac1.addLocalFloorDiv({1, 0, 0}, 2); + fac1.addEquality({1, 0, -3, 0}); + + // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); + fac2.addLocalFloorDiv({1, 0, 0}, 5); + fac2.addEquality({1, 0, -5, 0}); + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } +} + +TEST(FlatAffineConstraintsTest, mergeDivisionsUnsupported) { + // Division merging for divisions depending on other local variables + // is not yet supported. + + // (x) : (exists y = [x / 2], z = [(x + y) / 3]: y + z >= x).
+ FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 2); + fac1.addLocalFloorDiv({1, 1, 0}, 3); + fac1.addInequality({-1, 1, 1, 0}); + + // (x) : (exists y = [x / 2], z = [(x + y) / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); + fac2.addLocalFloorDiv({1, 1, 0}, 3); + fac2.addInequality({1, -1, -1, 0}); + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 1 division matched + 2 unmerged divisions due to divisions depending on + // other local variables. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); +} + } // namespace mlir