diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -160,8 +160,20 @@
   /// which an explicit representation for each local variable can be computed
   /// The pairs are stored as indices of upperbound, lowerbound
   /// inequalities. If no such pair can be found, it is stored as llvm::None.
-  void getLocalReprLbUbPairs(
+  /// The dividends of the explicit representations are stored in `dividends`
+  /// and their denominators in `denominator`; the denominator is set to 0 for
+  /// local variables without an explicit representation. Every output vector
+  /// must already hold exactly one entry per local variable.
+  void getLocalReprs(
+      std::vector<SmallVector<int64_t, 8>> &dividends,
+      SmallVector<unsigned, 4> &denominator,
       std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> &repr) const;
+  /// Same as above, but only returns the inequality pairs.
+  void getLocalReprs(
+      std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> &repr) const;
+  /// Same as above, but only returns the dividends and denominators.
+  void getLocalReprs(std::vector<SmallVector<int64_t, 8>> &dividends,
+                     SmallVector<unsigned, 4> &denominator) const;
 
   // Clones this object.
   std::unique_ptr<FlatAffineConstraints> clone() const;
@@ -430,10 +442,11 @@
   /// variables.
   void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
 
-  /// Merge local ids of `this` and `other`. This is done by appending local ids
-  /// of `other` to `this` and inserting local ids of `this` to `other` at start
-  /// of its local ids. Number of dimension and symbol ids should match in
-  /// `this` and `other`.
+  /// Merges and aligns local ids of `this` and `other` so that both end up
+  /// with the same number of local ids, matched position by position. Local
+  /// ids with identical division representations that do not depend on other
+  /// local ids are merged into one. The number of dimension and symbol ids
+  /// should match in `this` and `other`.
   void mergeLocalIds(FlatAffineConstraints &other);
 
   /// Removes all equalities and inequalities.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1352,13 +1352,89 @@
   return true;
 }
 
+/// Check if the pos^th identifier can be represented as a division using upper
+/// bound inequality at position `ubIneq` and lower bound inequality at position
+/// `lbIneq`.
+///
+/// Let `id` be the pos^th identifier, then `id` is equivalent to
+/// `expr floordiv divisor` if there are constraints of the form:
+///      0 <= expr - divisor * id <= divisor - 1
+/// Rearranging, we have:
+///       divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
+///      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
+///
+/// For example:
+///     32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
+///     32*k  <= 16*i + j                     <-- Upper bound for 'k'
+///     expr = 16*i + j, divisor = 32
+///     k = ( 16*i + j ) floordiv 32
+///
+///     4q >= i + j - 2                       <-- Lower bound for 'q'
+///     4q <= i + j + 1                       <-- Upper bound for 'q'
+///     expr = i + j + 1, divisor = 4
+///     q = (i + j + 1) floordiv 4
+///
+/// If successful, `expr` is set to the dividend of the division and `divisor`
+/// to its denominator; on failure, neither is modified.
+static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos,
+                                unsigned ubIneq, unsigned lbIneq,
+                                SmallVector<int64_t, 8> &expr,
+                                unsigned &divisor) {
+
+  assert(pos < cst.getNumIds() && "Invalid identifier position");
+  assert(ubIneq < cst.getNumInequalities() &&
+         "Invalid upper bound inequality position");
+  assert(lbIneq < cst.getNumInequalities() &&
+         "Invalid lower bound inequality position");
+
+  // Due to the form of the inequalities, sum of constants of the
+  // inequalities is (divisor - 1).
+  int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
+                        cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1;
+
+  // Divisor should be positive.
+  if (denominator <= 0)
+    return failure();
+
+  // Check if coeff of variable is equal to divisor.
+  if (denominator != cst.atIneq(lbIneq, pos))
+    return failure();
+
+  // Check if constraints are opposite of each other. Constant term
+  // is not required to be opposite and is not checked.
+  unsigned i = 0, e = 0;
+  for (i = 0, e = cst.getNumIds(); i < e; ++i)
+    if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i))
+      break;
+
+  if (i < e)
+    return failure();
+
+  // Set expr with dividend of the division.
+  SmallVector<int64_t, 8> dividend(cst.getNumCols());
+  for (i = 0, e = cst.getNumCols(); i < e; ++i)
+    if (i != pos)
+      dividend[i] = cst.atIneq(ubIneq, i);
+  expr = dividend;
+
+  // Set divisor.
+  divisor = denominator;
+
+  return success();
+}
+
 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
-/// function of other identifiers (where the divisor is a positive constant),
+/// function of other identifiers (where the divisor is a positive constant).
 /// `foundRepr` contains a boolean for each identifier indicating if the
 /// explicit representation for that identifier has already been computed.
+/// Returns the upper-lower bound inequality pair using which the floordiv can
+/// be computed, and sets `dividend` and `divisor` to the dividend and divisor
+/// of that floordiv. If no representation could be computed, `llvm::None` is
+/// returned and the values left in `dividend` and `divisor` must not be used.
 static Optional<std::pair<unsigned, unsigned>>
 computeSingleVarRepr(const FlatAffineConstraints &cst,
-                     const SmallVector<bool, 8> &foundRepr, unsigned pos) {
+                     const SmallVector<bool, 8> &foundRepr, unsigned pos,
+                     SmallVector<int64_t, 8> &dividend, unsigned &divisor) {
   assert(pos < cst.getNumIds() && "invalid position");
   assert(foundRepr.size() == cst.getNumIds() &&
          "Size of foundRepr does not match total number of variables");
@@ -1366,55 +1442,20 @@
   SmallVector<unsigned, 4> lbIndices, ubIndices;
   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
 
-  // `id` is equivalent to `expr floordiv divisor` if there
-  // are constraints of the form:
-  //      0 <= expr - divisor * id <= divisor - 1
-  // Rearranging, we have:
-  //       divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
-  //      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
-  //
-  // For example:
-  //     32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
-  //     32*k  <= 16*i + j                     <-- Upper bound for 'k'
-  //     expr = 16*i + j, divisor = 32
-  //     k = ( 16*i + j ) floordiv 32
-  //
-  //     4q >= i + j - 2                       <-- Lower bound for 'q'
-  //     4q <= i + j + 1                       <-- Upper bound for 'q'
-  //     expr = i + j + 1, divisor = 4
-  //     q = (i + j + 1) floordiv 4
   for (unsigned ubPos : ubIndices) {
     for (unsigned lbPos : lbIndices) {
-      // Due to the form of the inequalities, sum of constants of the
-      // inequalities is (divisor - 1).
-      int64_t divisor = cst.atIneq(lbPos, cst.getNumCols() - 1) +
-                        cst.atIneq(ubPos, cst.getNumCols() - 1) + 1;
-
-      // Divisor should be positive.
-      if (divisor <= 0)
-        continue;
-
-      // Check if coeff of variable is equal to divisor.
-      if (divisor != cst.atIneq(lbPos, pos))
-        continue;
-
-      // Check if constraints are opposite of each other. Constant term
-      // is not required to be opposite and is not checked.
-      unsigned c = 0, f = 0;
-      for (c = 0, f = cst.getNumIds(); c < f; ++c)
-        if (cst.atIneq(ubPos, c) != -cst.atIneq(lbPos, c))
-          break;
-
-      if (c < f)
+      // Attempt to get division representation from ubPos, lbPos.
+      if (failed(getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor)))
         continue;
 
       // Check if the inequalities depend on a variable for which
       // an explicit representation has not been found yet.
       // Exit to avoid circular dependencies between divisions.
+      unsigned c, f;
       for (c = 0, f = cst.getNumIds(); c < f; ++c) {
         if (c == pos)
           continue;
-        if (!foundRepr[c] && cst.atIneq(lbPos, c) != 0)
+        if (!foundRepr[c] && dividend[c] != 0)
           break;
       }
 
@@ -1432,14 +1473,32 @@
   return llvm::None;
 }
 
-/// Find pairs of inequalities identified by their position indices, using
-/// which an explicit representation for each local variable can be computed
-/// The pairs are stored as indices of upperbound, lowerbound
-/// inequalities. If no such pair can be found, it is stored as llvm::None.
-void FlatAffineConstraints::getLocalReprLbUbPairs(
+void FlatAffineConstraints::getLocalReprs(
     std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> &repr) const {
+  std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
+  SmallVector<unsigned, 4> denominator(getNumLocalIds());
+  getLocalReprs(dividends, denominator, repr);
+}
+
+void FlatAffineConstraints::getLocalReprs(
+    std::vector<SmallVector<int64_t, 8>> &dividends,
+    SmallVector<unsigned, 4> &denominator) const {
+  std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
+      getNumLocalIds());
+  getLocalReprs(dividends, denominator, repr);
+}
+
+void FlatAffineConstraints::getLocalReprs(
+    std::vector<SmallVector<int64_t, 8>> &dividends,
+    SmallVector<unsigned, 4> &denominator,
+    std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> &repr) const {
+
   assert(repr.size() == getNumLocalIds() &&
          "Size of repr does not match number of local variables");
+  assert(dividends.size() == getNumLocalIds() &&
+         "Size of dividends does not match number of local variables");
+  assert(denominator.size() == getNumLocalIds() &&
+         "Size of denominators does not match number of local variables");
 
   SmallVector<bool, 8> foundRepr(getNumIds(), false);
   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i)
@@ -1453,7 +1512,8 @@
     changed = false;
     for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
       if (!foundRepr[i + divOffset]) {
-        if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i)) {
+        if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i,
+                                            dividends[i], denominator[i])) {
           foundRepr[i + divOffset] = true;
           repr[i] = res;
           changed = true;
@@ -1461,6 +1521,12 @@
       }
     }
   } while (changed);
+
+  // Set 0 denominator for identifiers for which no division representation
+  // could be found.
+  for (unsigned i = 0, e = repr.size(); i < e; ++i)
+    if (!repr[i].hasValue())
+      denominator[i] = 0;
 }
 
 /// Tightens inequalities given that we are dealing with integer spaces. This is
@@ -1736,36 +1802,21 @@
     if (exprs[i])
       foundRepr[i] = true;
 
-  auto ulPair = computeSingleVarRepr(cst, foundRepr, pos);
+  SmallVector<int64_t, 8> dividend;
+  unsigned divisor;
+  auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
 
   // No upper-lower bound pair found for this var.
   if (!ulPair)
     return false;
 
-  unsigned ubPos = ulPair->first;
-
-  // Upper bound is of the form:
-  //      -divisor * id + expr >= 0
-  // where `id` is equivalent to `expr floordiv divisor`.
-  //
-  // Since the division cannot be dependent on itself, the coefficient of
-  // of `id` in `expr` is zero. The coefficient of `id` in the upperbound
-  // is -divisor.
-  int64_t divisor = -cst.atIneq(ubPos, pos);
-  int64_t constantTerm = cst.atIneq(ubPos, cst.getNumCols() - 1);
-
   // Construct the dividend expression.
-  auto dividendExpr = getAffineConstantExpr(constantTerm, context);
-  unsigned c, f;
-  for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
-    if (c == pos)
-      continue;
-    int64_t ubVal = cst.atIneq(ubPos, c);
-    if (ubVal == 0)
-      continue;
-    // computeSingleVarRepr guarantees that expr is known here.
-    dividendExpr = dividendExpr + ubVal * exprs[c];
-  }
+  // computeSingleVarRepr rejects dividends that use an identifier without a
+  // known expression, so exprs[c] is non-null wherever dividend[c] != 0.
+  auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
+  for (unsigned c = 0, f = cst.getNumIds(); c < f; c++)
+    if (dividend[c] != 0)
+      dividendExpr = dividendExpr + dividend[c] * exprs[c];
 
   // Successfully detected the floordiv.
   exprs[pos] = dividendExpr.floorDiv(divisor);
@@ -1872,18 +1923,108 @@
   equalities.resizeVertically(pos);
 }
 
-/// Merge local ids of `this` and `other`. This is done by appending local ids
-/// of `other` to `this` and inserting local ids of `this` to `other` at start
-/// of its local ids. Number of dimension and symbol ids should match in
-/// `this` and `other`.
+/// Merge local identifier at `pos2` into local identifier at `pos1` in `fac`.
+/// Both positions are relative to the first local identifier. The
+/// coefficients of the identifier at `pos2` are added to those of the
+/// identifier at `pos1` and the identifier at `pos2` is then removed, which
+/// is only sound if both identifiers are known to be equal.
+static void mergeDivision(FlatAffineConstraints &fac, unsigned pos1,
+                          unsigned pos2) {
+  unsigned localOffset = fac.getNumDimAndSymbolIds();
+  pos1 += localOffset;
+  pos2 += localOffset;
+  for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
+    fac.atIneq(i, pos1) += fac.atIneq(i, pos2);
+  for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
+    fac.atEq(i, pos1) += fac.atEq(i, pos2);
+  fac.removeId(pos2);
+}
+
 void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
   assert(getNumDimIds() == other.getNumDimIds() &&
          "Number of dimension ids should match");
   assert(getNumSymbolIds() == other.getNumSymbolIds() &&
          "Number of symbol ids should match");
-  unsigned initLocals = getNumLocalIds();
-  insertLocalId(getNumLocalIds(), other.getNumLocalIds());
-  other.insertLocalId(0, initLocals);
+
+  FlatAffineConstraints &fac1 = *this;
+  FlatAffineConstraints &fac2 = other;
+
+  // Get the division representation of every local id of each FAC. Local
+  // ids without a division representation get a denominator of 0.
+  std::vector<SmallVector<int64_t, 8>> divs1(fac1.getNumLocalIds()),
+      divs2(fac2.getNumLocalIds());
+  SmallVector<unsigned, 4> denoms1(fac1.getNumLocalIds()),
+      denoms2(fac2.getNumLocalIds());
+  fac1.getLocalReprs(divs1, denoms1);
+  fac2.getLocalReprs(divs2, denoms2);
+
+  // Merge local ids of fac1 and fac2 without using division information,
+  // i.e. append local ids of `fac2` to `fac1` and insert local ids of `fac1`
+  // to `fac2` at start of its local ids.
+  unsigned initLocals = fac1.getNumLocalIds();
+  fac1.insertLocalId(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+  fac2.insertLocalId(0, initLocals);
+
+  // Merge division representation extracted from fac1 and fac2. Each
+  // dividend keeps the column layout of the FAC it was extracted from, so
+  // dividends from fac1 and fac2 may differ in size and in the position of
+  // their constant term (always the last element).
+  divs1.insert(divs1.end(), divs2.begin(), divs2.end());
+  denoms1.insert(denoms1.end(), denoms2.begin(), denoms2.end());
+
+  // In every dividend, the local id coefficients start right after the
+  // dimension and symbol coefficients.
+  unsigned localOffset = fac1.getNumDimAndSymbolIds();
+
+  // Returns true if `div` has a non-zero coefficient for some local id. The
+  // last element of `div` is the constant term and is not a local id.
+  auto dependsOnLocal = [localOffset](const SmallVector<int64_t, 8> &div) {
+    for (unsigned k = localOffset, e = div.size() - 1; k < e; ++k)
+      if (div[k] != 0)
+        return true;
+    return false;
+  };
+
+  // Returns true if the dividends `a` and `b`, neither of which depends on
+  // a local id, agree on every dimension and symbol coefficient and on the
+  // constant term.
+  auto sameDividend = [localOffset](const SmallVector<int64_t, 8> &a,
+                                    const SmallVector<int64_t, 8> &b) {
+    return std::equal(a.begin(), a.begin() + localOffset, b.begin()) &&
+           a.back() == b.back();
+  };
+
+  // Find duplicate divisions and merge them.
+  // TODO: Add division normalization to support divisions that differ by
+  // a constant.
+  for (unsigned i = 0; i < divs1.size(); ++i) {
+    // Skip local ids without a division representation.
+    if (denoms1[i] == 0)
+      continue;
+    // If division representation contains a local variable, do not match.
+    // TODO: Support divisions that depend on other local ids. This can
+    // be done by ordering divisions such that a division representation
+    // for local identifier at position `i` only depends on local identifiers
+    // at position < `i`.
+    if (dependsOnLocal(divs1[i]))
+      continue;
+    // Check if a division exists which is duplicate of division at `i`.
+    for (unsigned j = i + 1; j < divs1.size(); ++j) {
+      // Denominators must match; this also skips local ids without a
+      // division representation, whose denominator is 0.
+      if (denoms1[i] != denoms1[j])
+        continue;
+      if (dependsOnLocal(divs1[j]) || !sameDividend(divs1[i], divs1[j]))
+        continue;
+
+      // Merge divisions at position `j` into division at position `i`.
+      mergeDivision(fac1, i, j);
+      mergeDivision(fac2, i, j);
+      divs1.erase(divs1.begin() + j);
+      denoms1.erase(denoms1.begin() + j);
+      --j;
+    }
+  }
 }
 
 /// Removes local variables using equalities. Each equality is checked if it
diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -196,7 +196,7 @@
   // the local variables of sI.
   std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
       sI.getNumLocalIds());
-  sI.getLocalReprLbUbPairs(repr);
+  sI.getLocalReprs(repr);
 
   // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
   // sI's locals.
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -648,7 +648,7 @@
 
   std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> res(
       fac.getNumLocalIds(), llvm::None);
-  fac.getLocalReprLbUbPairs(res);
+  fac.getLocalReprs(res);
 
   // Check if all expected divisions are computed.
   for (unsigned i = 0, e = fac.getNumLocalIds(); i < e; ++i)
@@ -795,4 +795,79 @@
   EXPECT_TRUE(fac3.isEmpty());
 }
 
+TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) {
+  {
+    // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0).
+    FlatAffineConstraints fac1(1, 0, 1);
+    fac1.addLocalFloorDiv({1, 0, 0}, 2);
+    fac1.addEquality({1, 0, -3, 0});
+    fac1.addInequality({1, 1, 0, 1});
+
+    // (x) : (exists y = [x / 2], z : x = 5y).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);
+    fac2.addEquality({1, -5, 0});
+    fac2.appendLocalId();
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local spaces should be the same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 1 division matched + 2 unmatched local variables.
+    EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+  }
+
+  {
+    // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 0}, 5);
+    fac1.addLocalFloorDiv({1, 0, 0}, 2);
+    fac1.addEquality({1, 0, -3, 0});
+
+    // (x) : (exists y = [x / 2], z = [x / 5] : x = 5z).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);
+    fac2.addLocalFloorDiv({1, 0, 0}, 5);
+    fac2.addEquality({1, 0, -5, 0});
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local spaces should be the same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 2 divisions matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
+}
+
+TEST(FlatAffineConstraintsTest, mergeDivisionsUnsupported) {
+  // Merging divisions that depend on other local variables is
+  // not yet supported.
+
+  // (x) : (exists y = [x / 2], z = [(x + y) / 3] : y + z >= x).
+  FlatAffineConstraints fac1(1);
+  fac1.addLocalFloorDiv({1, 0}, 2);
+  fac1.addLocalFloorDiv({1, 1, 0}, 3);
+  fac1.addInequality({-1, 1, 1, 0});
+
+  // (x) : (exists y = [x / 2], z = [(x + y) / 3] : y + z <= x).
+  FlatAffineConstraints fac2(1);
+  fac2.addLocalFloorDiv({1, 0}, 2);
+  fac2.addLocalFloorDiv({1, 1, 0}, 3);
+  fac2.addInequality({1, -1, -1, 0});
+
+  fac1.mergeLocalIds(fac2);
+
+  // Local spaces should be the same.
+  EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+  // 1 division matched + 2 unmerged divisions, because those divisions
+  // depend on other local variables.
+  EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+  EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+}
+
 } // namespace mlir