diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -157,11 +157,21 @@ bool containsPoint(ArrayRef point) const; /// Find pairs of inequalities identified by their position indices, using - /// which an explicit representation for each local variable can be computed - /// The pairs are stored as indices of upperbound, lowerbound - /// inequalities. If no such pair can be found, it is stored as llvm::None. - void getLocalReprLbUbPairs( + /// which an explicit representation for each local variable can be computed. + /// The pairs are stored as indices of upperbound, lowerbound inequalities. If + /// no such pair can be found, it is stored as llvm::None. + /// + /// The dividends of the explicit representations are stored in `dividends` + /// and the denominators in `denominators`. `denominators[i]` is set to 0 if + /// no explicit representation could be found for the `i^th` local identifier. + void getLocalReprs( + std::vector> &dividends, + SmallVector &denominators, + std::vector>> &repr) const; + void getLocalReprs( std::vector>> &repr) const; + void getLocalReprs(std::vector> &dividends, + SmallVector &denominators) const; // Clones this object. std::unique_ptr clone() const; @@ -430,10 +440,9 @@ /// variables. void convertDimToLocal(unsigned dimStart, unsigned dimLimit); - /// Merge local ids of `this` and `other`. This is done by appending local ids - /// of `other` to `this` and inserting local ids of `this` to `other` at start - /// of its local ids. Number of dimension and symbol ids should match in - /// `this` and `other`. + /// Merges and aligns local ids of `this` and `other`. Local ids with + /// identical division representations are merged. The number of dimensions + /// and symbol ids in `this` and `other` should match.
void mergeLocalIds(FlatAffineConstraints &other); /// Removes all equalities and inequalities. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1424,12 +1424,17 @@ } /// Check if the pos^th identifier can be expressed as a floordiv of an affine -/// function of other identifiers (where the divisor is a positive constant), +/// function of other identifiers (where the divisor is a positive constant). /// `foundRepr` contains a boolean for each identifier indicating if the /// explicit representation for that identifier has already been computed. +/// Returns the upper-lower bound inequality pair using which the floordiv can +/// be computed. If the representation could be computed, `dividend` and +/// `divisor` are set to its dividend and divisor. Otherwise, `llvm::None` is +/// returned. static Optional> computeSingleVarRepr(const FlatAffineConstraints &cst, - const SmallVector &foundRepr, unsigned pos) { + const SmallVector &foundRepr, unsigned pos, + SmallVector &dividend, unsigned &divisor) { assert(pos < cst.getNumIds() && "invalid position"); assert(foundRepr.size() == cst.getNumIds() && "Size of foundRepr does not match total number of variables"); @@ -1440,9 +1445,7 @@ for (unsigned ubPos : ubIndices) { for (unsigned lbPos : lbIndices) { // Attempt to get divison representation from ubPos, lbPos.
- SmallVector expr; - unsigned divisor; - if (failed(getDivRepr(cst, pos, ubPos, lbPos, expr, divisor))) + if (failed(getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor))) continue; // Check if the inequalities depend on a variable for which @@ -1452,7 +1455,7 @@ for (c = 0, f = cst.getNumIds(); c < f; ++c) { if (c == pos) continue; - if (!foundRepr[c] && expr[c] != 0) + if (!foundRepr[c] && dividend[c] != 0) break; } @@ -1470,14 +1473,29 @@ return llvm::None; } -/// Find pairs of inequalities identified by their position indices, using -/// which an explicit representation for each local variable can be computed -/// The pairs are stored as indices of upperbound, lowerbound -/// inequalities. If no such pair can be found, it is stored as llvm::None. -void FlatAffineConstraints::getLocalReprLbUbPairs( +void FlatAffineConstraints::getLocalReprs( std::vector>> &repr) const { - assert(repr.size() == getNumLocalIds() && - "Size of repr does not match number of local variables"); + std::vector> dividends(getNumLocalIds()); + SmallVector denominators(getNumLocalIds()); + getLocalReprs(dividends, denominators, repr); +} + +void FlatAffineConstraints::getLocalReprs( + std::vector> &dividends, + SmallVector &denominators) const { + std::vector>> repr( + getNumLocalIds()); + getLocalReprs(dividends, denominators, repr); +} + +void FlatAffineConstraints::getLocalReprs( + std::vector> &dividends, + SmallVector &denominators, + std::vector>> &repr) const { + + repr.resize(getNumLocalIds()); + dividends.resize(getNumLocalIds()); + denominators.resize(getNumLocalIds()); SmallVector foundRepr(getNumIds(), false); for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) @@ -1491,7 +1509,8 @@ changed = false; for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { if (!foundRepr[i + divOffset]) { - if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i)) { + if (auto res = computeSingleVarRepr(*this, foundRepr, divOffset + i, + dividends[i], denominators[i])) { foundRepr[i
+ divOffset] = true; repr[i] = res; changed = true; @@ -1499,6 +1518,12 @@ } } } while (changed); + + // Set 0 denominator for identifiers for which no division representation + // could be found. + for (unsigned i = 0, e = repr.size(); i < e; ++i) + if (!repr[i].hasValue()) + denominators[i] = 0; } /// Tightens inequalities given that we are dealing with integer spaces. This is @@ -1774,36 +1799,19 @@ if (exprs[i]) foundRepr[i] = true; - auto ulPair = computeSingleVarRepr(cst, foundRepr, pos); + SmallVector dividend; + unsigned divisor; + auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); // No upper-lower bound pair found for this var. if (!ulPair) return false; - unsigned ubPos = ulPair->first; - - // Upper bound is of the form: - // -divisor * id + expr >= 0 - // where `id` is equivalent to `expr floordiv divisor`. - // - // Since the division cannot be dependent on itself, the coefficient of - // of `id` in `expr` is zero. The coefficient of `id` in the upperbound - // is -divisor. - int64_t divisor = -cst.atIneq(ubPos, pos); - int64_t constantTerm = cst.atIneq(ubPos, cst.getNumCols() - 1); - // Construct the dividend expression. - auto dividendExpr = getAffineConstantExpr(constantTerm, context); - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (c == pos) - continue; - int64_t ubVal = cst.atIneq(ubPos, c); - if (ubVal == 0) - continue; - // computeSingleVarRepr guarantees that expr is known here. - dividendExpr = dividendExpr + ubVal * exprs[c]; - } + auto dividendExpr = getAffineConstantExpr(dividend.back(), context); + for (unsigned c = 0, f = cst.getNumIds(); c < f; c++) + if (dividend[c] != 0) + dividendExpr = dividendExpr + dividend[c] * exprs[c]; // Successfully detected the floordiv. exprs[pos] = dividendExpr.floorDiv(divisor); @@ -1910,18 +1918,102 @@ equalities.resizeVertically(pos); } -/// Merge local ids of `this` and `other`.
This is done by appending local ids -/// of `other` to `this` and inserting local ids of `this` to `other` at start -/// of its local ids. Number of dimension and symbol ids should match in -/// `this` and `other`. +/// Eliminates the `pos2^th` local identifier, replacing every occurrence of it +/// with the `pos1^th` local identifier. This is intended to remove redundancy +/// when the local variables at positions `pos1` and `pos2` are restricted to +/// have the same value. +static void eliminateRedundantLocalId(FlatAffineConstraints &fac, unsigned pos1, + unsigned pos2) { + + assert(pos1 < fac.getNumLocalIds() && "Invalid local id position"); + assert(pos2 < fac.getNumLocalIds() && "Invalid local id position"); + + unsigned localOffset = fac.getNumDimAndSymbolIds(); + pos1 += localOffset; + pos2 += localOffset; + for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i) + fac.atIneq(i, pos1) += fac.atIneq(i, pos2); + for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i) + fac.atEq(i, pos1) += fac.atEq(i, pos2); + fac.removeId(pos2); +} + void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) { assert(getNumDimIds() == other.getNumDimIds() && "Number of dimension ids should match"); assert(getNumSymbolIds() == other.getNumSymbolIds() && "Number of symbol ids should match"); - unsigned initLocals = getNumLocalIds(); - insertLocalId(getNumLocalIds(), other.getNumLocalIds()); - other.insertLocalId(0, initLocals); + + FlatAffineConstraints &fac1 = *this; + FlatAffineConstraints &fac2 = other; + + // Get division representations from each FAC. + std::vector> divs1, divs2; + SmallVector denoms1, denoms2; + fac1.getLocalReprs(divs1, denoms1); + fac2.getLocalReprs(divs2, denoms2); + + // Merge local ids of fac1 and fac2 without using division information, + // i.e. append local ids of `fac2` to `fac1` and insert local ids of `fac1` + // to `fac2` at start of its local ids. Also, insert these local ids in + // division representation.
+ unsigned initLocals = fac1.getNumLocalIds(); + for (unsigned i = 0, e = divs1.size(); i < e; ++i) + if (denoms1[i] != 0) + divs1[i].insert(divs1[i].begin() + fac1.getNumIds(), + fac2.getNumLocalIds(), 0); + insertLocalId(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + for (unsigned i = 0, e = divs2.size(); i < e; ++i) + if (denoms2[i] != 0) + divs2[i].insert(divs2[i].begin() + fac2.getIdKindOffset(IdKind::Local), + initLocals, 0); + fac2.insertLocalId(0, initLocals); + + // Merge division representations extracted from fac1 and fac2. + divs1.insert(divs1.end(), divs2.begin(), divs2.end()); + denoms1.insert(denoms1.end(), denoms2.begin(), denoms2.end()); + + // Find and merge duplicate divisions. + // TODO: Add division normalization to support divisions that differ by + // a constant. + // TODO: Add division ordering such that a division representation for local + // identifier at position `i` only depends on local identifiers at position < + // `i`. This makes sure that all divisions depending on other local variables + // that can be merged, are merged. + unsigned localOffset = getIdKindOffset(IdKind::Local); + for (unsigned i = 0; i < divs1.size(); ++i) { + // Check if a division representation exists for the `i^th` local id. + if (denoms1[i] == 0) + continue; + // Check if a division exists which is a duplicate of the division at `i`. + for (unsigned j = i + 1; j < divs1.size(); ++j) { + // Check if a division representation exists for the `j^th` local id. + if (denoms1[j] == 0) + continue; + // Check if the denominators match. + if (denoms1[i] != denoms1[j]) + continue; + // Check if the representations are equal. + if (!std::equal(divs1[i].begin(), divs1[i].end(), divs1[j].begin())) + continue; + + // Merge divisions at position `j` into division at position `i`.
+ eliminateRedundantLocalId(fac1, i, j); + eliminateRedundantLocalId(fac2, i, j); + for (unsigned k = 0, g = divs1.size(); k < g; ++k) { + SmallVector &div = divs1[k]; + if (denoms1[k] != 0) { + div[localOffset + i] += div[localOffset + j]; + div.erase(div.begin() + localOffset + j); + } + } + + divs1.erase(divs1.begin() + j); + denoms1.erase(denoms1.begin() + j); + --j; + } + } } /// Removes local variables using equalities. Each equality is checked if it diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp --- a/mlir/lib/Analysis/PresburgerSet.cpp +++ b/mlir/lib/Analysis/PresburgerSet.cpp @@ -196,7 +196,7 @@ // the local variables of sI. std::vector>> repr( sI.getNumLocalIds()); - sI.getLocalReprLbUbPairs(repr); + sI.getLocalReprs(repr); // Add sI's locals to b, after b's locals. Also add b's locals to sI, before // sI's locals. diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -648,7 +648,7 @@ std::vector>> res( fac.getNumLocalIds(), llvm::None); - fac.getLocalReprLbUbPairs(res); + fac.getLocalReprs(res); // Check if all expected divisions are computed. for (unsigned i = 0, e = fac.getNumLocalIds(); i < e; ++i) @@ -795,4 +795,127 @@ EXPECT_TRUE(fac3.isEmpty()); } +TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) { + { + // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). + FlatAffineConstraints fac1(1, 0, 1); + fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2]. + fac1.addEquality({1, 0, -3, 0}); // x = 3y. + fac1.addInequality({1, 1, 0, 1}); // x + z + 1 >= 0. + + // (x) : (exists y = [x / 2], z : x = 5y). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addEquality({1, -5, 0}); // x = 5y. + fac2.appendLocalId(); // Add local id z. + + fac1.mergeLocalIds(fac2); + + // Local space should be same.
+ EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 1 division should be matched + 2 unmatched local ids. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } + + { + // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 5); // z = [x / 5]. + fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2]. + fac1.addEquality({1, 0, -3, 0}); // x = 3y. + + // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({1, 0, 0}, 5); // z = [x / 5]. + fac2.addEquality({1, 0, -5, 0}); // x = 5z. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } +} + +TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivisions) { + { + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } + + { + // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
+ fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac1.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5]. + fac1.addInequality({-1, 1, 1, 0, 0}); // y + z >= x. + + // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac2.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5]. + fac2.addInequality({1, -1, -1, 0, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 3 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } +} + +TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) { + { + // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. + fac1.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. + fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } +} + } // namespace mlir