diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -352,22 +352,16 @@ Optional> containsPointNoLocal(ArrayRef point) const; - /// Find equality and pairs of inequality constraints identified by their - /// position indices, using which an explicit representation for each local - /// variable can be computed. The indices of the constraints are stored in - /// `MaybeLocalRepr` struct. If no such pair can be found, the kind attribute - /// in `MaybeLocalRepr` is set to None. + /// Returns a `DivisionRepr` holding the division representation of local + /// variables in the constraint system. /// - /// The dividends of the explicit representations are stored in `dividends` - /// and the denominators in `denominators`. If no explicit representation - /// could be found for the `i^th` local variable, `denominators[i]` is set - /// to 0. - void getLocalReprs(std::vector> &dividends, - SmallVector &denominators, - std::vector &repr) const; - void getLocalReprs(std::vector &repr) const; - void getLocalReprs(std::vector> &dividends, - SmallVector &denominators) const; + /// If `repr` is not `nullptr`, the equality and pairs of inequality + /// constraints identified by their position indices using which an explicit + /// representation for each local variable can be computed are set in `repr` + /// in the form of a `MaybeLocalRepr` struct. If no such inequality + /// pair/equality can be found, the kind attribute in `MaybeLocalRepr` is set + /// to None. + DivisionRepr getLocalReprs(std::vector *repr = nullptr) const; /// The type of bound: equal, lower bound or upper bound.
enum BoundType { EQ, LB, UB }; diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -89,6 +89,9 @@ MutableArrayRef getRow(unsigned row); ArrayRef getRow(unsigned row) const; + /// Set the specified row to `elems`. + void setRow(unsigned row, ArrayRef elems); + /// Insert columns having positions pos, pos + 1, ... pos + count - 1. /// Columns that were at positions 0 to pos - 1 will stay where they are; /// columns that were at positions pos to nColumns - 1 will be pushed to the diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -17,6 +17,8 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "mlir/Analysis/Presburger/Matrix.h" + namespace mlir { namespace presburger { @@ -102,6 +104,77 @@ } repr; }; +/// Class storing divisions of a constraint system. The divisions are part of +/// the constraint system as variables. Dividend coefficients are ordered as: +/// [nonDivVars, divVars, constant]. Each division may or may not have a +/// representation. If the division does not have a representation, the +/// dividend of the division has no meaning and the denominator is zero. +/// +/// The i^th division here represents the division representation of the +/// variable at position `divOffset + i` in the constraint system.
+class DivisionRepr { +public: + DivisionRepr(unsigned numVars, unsigned numDivs) + : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {} + + DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {} + + unsigned getNumVars() const { return dividends.getNumColumns() - 1; } + unsigned getNumDivs() const { return dividends.getNumRows(); } + unsigned getNumNonDivs() const { return getNumVars() - getNumDivs(); } + // Get the offset from where division variables start. + unsigned getDivOffset() const { return getNumVars() - getNumDivs(); } + + // Check whether the `i^th` division has a division representation or not. + bool hasDivRepr(unsigned i) const { return denoms[i] != 0; } + // Check whether all the divisions have a division representation or not. + bool hasAllDivReprs() const { + return all_of(denoms, [](unsigned denom) { return denom != 0; }); + } + + // Get the dividend of the `i^th` division. + MutableArrayRef getDividend(unsigned i) { + return dividends.getRow(i); + } + ArrayRef getDividend(unsigned i) const { + return dividends.getRow(i); + } + + // Get the `i^th` denominator. + unsigned &getDenom(unsigned i) { return denoms[i]; } + unsigned getDenom(unsigned i) const { return denoms[i]; } + + ArrayRef getDenoms() const { return denoms; } + + void setDividend(unsigned i, ArrayRef dividend) { + dividends.setRow(i, dividend); + } + + /// Removes duplicate divisions. On every possible duplicate division found, + /// `merge(i, j)`, where `i`, `j` are current indices of the duplicate + /// divisions, is called and division at index `j` is merged into division at + /// index `i`. If `merge(i, j)` returns `true`, the divisions are merged i.e. + /// `j^th` division gets eliminated and each of its instances is replaced by + /// `i^th` division. If it returns `false`, the divisions are not merged. + /// `merge` can also have side effects. For example, it can merge the local + /// variables in IntegerRelation.
+ void + removeDuplicateDivs(llvm::function_ref merge); + + void print(raw_ostream &os) const; + void dump() const; + +private: + /// Each row of the Matrix represents a single division dividend. The + /// `i^th` row represents the dividend of the variable at `divOffset + i` + /// in the constraint system (and the `i^th` division variable). + Matrix dividends; + + /// Denominators of each division. If a denominator of a division is `0`, the + /// division variable is considered to not have a division representation. + SmallVector denoms; +}; + /// If `q` is defined to be equal to `expr floordiv d`, this equivalent to /// saying that `q` is an integer and `q` is subject to the inequalities /// `0 <= expr - d*q <= c - 1` (quotient remainder theorem). @@ -135,25 +208,9 @@ /// `MaybeLocalRepr` is set to None. MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, - SmallVector &dividend, + MutableArrayRef dividend, unsigned &divisor); -/// Given dividends of divisions `divs` and denominators `denoms`, detects and -/// removes duplicate divisions. `localOffset` is the offset in dividend of a -/// division from where local variables start. -/// -/// On every possible duplicate division found, `merge(i, j)`, where `i`, `j` -/// are current index of the duplicate divisions, is called and division at -/// index `j` is merged into division at index `i`. If `merge(i, j)` returns -/// `true`, the divisions are merged i.e. `j^th` division gets eliminated and -/// it's each instance is replaced by `i^th` division. If it returns `false`, -/// the divisions are not merged. `merge` can also do side effects, For example -/// it can merge the local variables in IntegerRelation.
-void removeDuplicateDivs( - std::vector> &divs, - SmallVectorImpl &denoms, unsigned localOffset, - llvm::function_ref merge); - /// Given two relations, A and B, add additional local vars to the sets such /// that both have the union of the local vars in each set, without changing /// the set of points that lie in A and B. diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -175,8 +175,8 @@ // // Take a copy so we can perform mutations. IntegerRelation copy = *this; - std::vector reprs; - copy.getLocalReprs(reprs); + std::vector reprs(getNumLocalVars()); + copy.getLocalReprs(&reprs); // Iterate through all the locals. The last `numNonDivLocals` are the locals // that have been scanned already and do not have division representations. @@ -907,33 +907,17 @@ return copy.findIntegerSample(); } -void IntegerRelation::getLocalReprs(std::vector &repr) const { - std::vector> dividends(getNumLocalVars()); - SmallVector denominators(getNumLocalVars()); - getLocalReprs(dividends, denominators, repr); -} - -void IntegerRelation::getLocalReprs( - std::vector> &dividends, - SmallVector &denominators) const { - std::vector repr(getNumLocalVars()); - getLocalReprs(dividends, denominators, repr); -} - -void IntegerRelation::getLocalReprs( - std::vector> &dividends, - SmallVector &denominators, - std::vector &repr) const { - - repr.resize(getNumLocalVars()); - dividends.resize(getNumLocalVars()); - denominators.resize(getNumLocalVars()); - +DivisionRepr +IntegerRelation::getLocalReprs(std::vector *repr) const { SmallVector foundRepr(getNumVars(), false); for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) foundRepr[i] = true; - unsigned divOffset = getNumDimAndSymbolVars(); + unsigned divOffset = getVarKindOffset(VarKind::Local); + DivisionRepr divs(getNumVars(), getNumLocalVars()); + // Reset `repr` so that every local starts out without a representation. + if (repr) + repr->assign(getNumLocalVars(), MaybeLocalRepr()); bool changed; do { // Each
time changed is true, at end of this iteration, one or more local @@ -941,22 +925,24 @@ changed = false; for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) { if (!foundRepr[i + divOffset]) { - MaybeLocalRepr res = computeSingleVarRepr( - *this, foundRepr, divOffset + i, dividends[i], denominators[i]); - if (!res) + MaybeLocalRepr res = + computeSingleVarRepr(*this, foundRepr, divOffset + i, + divs.getDividend(i), divs.getDenom(i)); + if (!res) { + // No representation was found, so set denominator to be 0 and + // continue. + divs.getDenom(i) = 0; continue; + } foundRepr[i + divOffset] = true; - repr[i] = res; + if (repr) + (*repr)[i] = res; changed = true; } } } while (changed); - // Set 0 denominator for variables for which no division representation - // could be found. - for (unsigned i = 0, e = repr.size(); i < e; ++i) - if (!repr[i]) - denominators[i] = 0; + return divs; } /// Tightens inequalities given that we are dealing with integer spaces. This is @@ -1206,23 +1192,16 @@ } bool IntegerRelation::hasOnlyDivLocals() const { - std::vector reprs; - getLocalReprs(reprs); - return llvm::all_of(reprs, - [](const MaybeLocalRepr &repr) { return bool(repr); }); + return getLocalReprs().hasAllDivReprs(); } void IntegerRelation::removeDuplicateDivs() { - std::vector> divs; - SmallVector denoms; - - getLocalReprs(divs, denoms); + DivisionRepr divs = getLocalReprs(); auto merge = [this](unsigned i, unsigned j) -> bool { eliminateRedundantLocalVar(i, j); return true; }; - presburger::removeDuplicateDivs(divs, denoms, - getVarKindOffset(VarKind::Local), merge); + divs.removeDuplicateDivs(merge); } /// Removes local variables using equalities.
Each equality is checked if it diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -92,6 +92,12 @@ return {&data[row * nReservedColumns], nColumns}; } +void Matrix::setRow(unsigned row, ArrayRef elems) { + assert(elems.size() == getNumColumns() && "elems must match row length!"); + for (unsigned i = 0, e = getNumColumns(); i < e; ++i) + at(row, i) = elems[i]; +} + void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); } void Matrix::insertColumns(unsigned pos, unsigned count) { if (count == 0) diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -253,10 +253,8 @@ // // Careful! This has to be done after the merge above; otherwise, the // dividends won't contain the new ids inserted during the merge. - std::vector repr; - std::vector> dividends; - SmallVector divisors; - sI.getLocalReprs(dividends, divisors, repr); + std::vector repr(sI.getNumLocalVars()); + DivisionRepr divs = sI.getLocalReprs(&repr); // Mark which inequalities of sI are division inequalities and add all // such inequalities to b. @@ -301,10 +299,10 @@ // not be because they were never a part of sI; we just infer them // from the equality and add them only to b.
b.addInequality( - getDivLowerBound(dividends[i], divisors[i], + getDivLowerBound(divs.getDividend(i), divs.getDenom(i), sI.getVarKindOffset(VarKind::Local) + i)); b.addInequality( - getDivUpperBound(dividends[i], divisors[i], + getDivUpperBound(divs.getDividend(i), divs.getDenom(i), sI.getVarKindOffset(VarKind::Local) + i)); } } diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -21,7 +21,7 @@ /// Normalize a division's `dividend` and the `divisor` by their GCD. For /// example: if the dividend and divisor are [2,0,4] and 4 respectively, /// they get normalized to [1,0,2] and 2. -static void normalizeDivisionByGCD(SmallVectorImpl &dividend, +static void normalizeDivisionByGCD(MutableArrayRef dividend, unsigned &divisor) { if (divisor == 0 || dividend.empty()) return; @@ -89,7 +89,7 @@ /// normalized by GCD. static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned ubIneq, unsigned lbIneq, - SmallVector &expr, + MutableArrayRef expr, unsigned &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); @@ -97,6 +97,7 @@ "Invalid upper bound inequality position"); assert(lbIneq <= cst.getNumInequalities() && "Invalid upper bound inequality position"); + assert(expr.size() == cst.getNumCols() && "Invalid expression size"); // Extract divisor from the lower bound. divisor = cst.atIneq(lbIneq, pos); @@ -126,7 +127,6 @@ // The inequality pair can be used to extract the division. // Set `expr` to the dividend of the division except the constant term, which // is set below. - expr.resize(cst.getNumCols(), 0); for (i = 0, e = cst.getNumVars(); i < e; ++i) if (i != pos) expr[i] = cst.atIneq(ubIneq, i); @@ -152,11 +152,12 @@ /// set to the denominator of the division. The final division expression is /// normalized by GCD.
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, - unsigned eqInd, SmallVector &expr, + unsigned eqInd, MutableArrayRef expr, unsigned &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(eqInd <= cst.getNumEqualities() && "Invalid equality position"); + assert(expr.size() == cst.getNumCols() && "Invalid expression size"); // Extract divisor, the divisor can be negative and hence its sign information // is stored in `signDiv` to reverse the sign of dividend's coefficients. @@ -169,7 +170,6 @@ // The divisor is always a positive integer. divisor = tempDiv * signDiv; - expr.resize(cst.getNumCols(), 0); for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i) if (i != pos) expr[i] = -signDiv * cst.atEq(eqInd, i); @@ -215,7 +215,7 @@ /// `MaybeLocalRepr` is set to None. MaybeLocalRepr presburger::computeSingleVarRepr( const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, - SmallVector &dividend, unsigned &divisor) { + MutableArrayRef dividend, unsigned &divisor) { assert(pos < cst.getNumVars() && "invalid position"); assert(foundRepr.size() == cst.getNumVars() && "Size of foundRepr does not match total number of variables"); @@ -261,57 +261,6 @@ return vec; } -void presburger::removeDuplicateDivs( - std::vector> &divs, - SmallVectorImpl &denoms, unsigned localOffset, - llvm::function_ref merge) { - - // Find and merge duplicate divisions. - // TODO: Add division normalization to support divisions that differ by - // a constant. - // TODO: Add division ordering such that a division representation for local - // variable at position `i` only depends on local variables at position < - // `i`. This would make sure that all divisions depending on other local - // variables that can be merged, are merged. - for (unsigned i = 0; i < divs.size(); ++i) { - // Check if a division representation exists for the `i^th` local var.
- if (denoms[i] == 0) - continue; - // Check if a division exists which is a duplicate of the division at `i`. - for (unsigned j = i + 1; j < divs.size(); ++j) { - // Check if a division representation exists for the `j^th` local var. - if (denoms[j] == 0) - continue; - // Check if the denominators match. - if (denoms[i] != denoms[j]) - continue; - // Check if the representations are equal. - if (divs[i] != divs[j]) - continue; - - // Merge divisions at position `j` into division at position `i`. If - // merge fails, do not merge these divs. - bool mergeResult = merge(i, j); - if (!mergeResult) - continue; - - // Update division information to reflect merging. - for (unsigned k = 0, g = divs.size(); k < g; ++k) { - SmallVector &div = divs[k]; - if (denoms[k] != 0) { - div[localOffset + i] += div[localOffset + j]; - div.erase(div.begin() + localOffset + j); - } - } - - divs.erase(divs.begin() + j); - denoms.erase(denoms.begin() + j); - // Since `j` can never be zero, we do not need to worry about overflows. - --j; - } - } -} - void presburger::mergeLocalVars( IntegerRelation &relA, IntegerRelation &relB, llvm::function_ref merge) { @@ -327,23 +276,17 @@ relB.insertVar(VarKind::Local, 0, initLocals); // Get division representations from each rel. - std::vector> divsA, divsB; - SmallVector denomsA, denomsB; - relA.getLocalReprs(divsA, denomsA); - relB.getLocalReprs(divsB, denomsB); - - // Copy division information for relB into `divsA` and `denomsA`, so that - // these have the combined division information of both rels. Since newly - // added local variables in relA and relB have no constraints, they will not - // have any division representation. - std::copy(divsB.begin() + initLocals, divsB.end(), - divsA.begin() + initLocals); - std::copy(denomsB.begin() + initLocals, denomsB.end(), - denomsA.begin() + initLocals); - - // Merge all divisions by removing duplicate divisions.
- unsigned localOffset = relA.getVarKindOffset(VarKind::Local); - presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge); + DivisionRepr divsA = relA.getLocalReprs(); + DivisionRepr divsB = relB.getLocalReprs(); + + for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) { + divsA.setDividend(i, divsB.getDividend(i)); + divsA.getDenom(i) = divsB.getDenom(i); + } + + // Remove duplicate divisions from divsA. `merge` is invoked for every + // duplicate found, to effectively merge the divisions in relA and relB. + divsA.removeDuplicateDivs(merge); } SmallVector presburger::getDivUpperBound(ArrayRef dividend, @@ -412,3 +355,59 @@ --coeffs.back(); return coeffs; } + +void DivisionRepr::removeDuplicateDivs( + llvm::function_ref merge) { + + // Find and merge duplicate divisions. + // TODO: Add division normalization to support divisions that differ by + // a constant. + // TODO: Add division ordering such that a division representation for local + // variable at position `i` only depends on local variables at position < + // `i`. This would make sure that all divisions depending on other local + // variables that can be merged, are merged. + for (unsigned i = 0; i < getNumDivs(); ++i) { + // Check if a division representation exists for the `i^th` local var. + if (denoms[i] == 0) + continue; + // Check if a division exists which is a duplicate of the division at `i`. + for (unsigned j = i + 1; j < getNumDivs(); ++j) { + // Check if a division representation exists for the `j^th` local var. + if (denoms[j] == 0) + continue; + // Check if the denominators match. + if (denoms[i] != denoms[j]) + continue; + // Check if the representations are equal. + if (dividends.getRow(i) != dividends.getRow(j)) + continue; + + // Merge divisions at position `j` into division at position `i`. If + // merge fails, do not merge these divs. + bool mergeResult = merge(i, j); + if (!mergeResult) + continue; + + // Update division information to reflect merging.
+ unsigned divOffset = getDivOffset(); + dividends.addToColumn(divOffset + j, divOffset + i, /*scale=*/1); + dividends.removeColumn(divOffset + j); + dividends.removeRow(j); + denoms.erase(denoms.begin() + j); + + // Since `j` can never be zero, we do not need to worry about overflows. + --j; + } + } +} + +void DivisionRepr::print(raw_ostream &os) const { + os << "Dividends:\n"; + dividends.print(os); + os << "Denominators:\n"; + for (unsigned i = 0, e = denoms.size(); i < e; ++i) + os << denoms[i] << " "; + os << "\n"; +} + +void DivisionRepr::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -865,7 +865,7 @@ if (exprs[i]) foundRepr[i] = true; - SmallVector dividend; + SmallVector dividend(cst.getNumCols()); unsigned divisor; auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor); diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -608,23 +608,19 @@ static void checkDivisionRepresentation( IntegerPolyhedron &poly, const std::vector> &expectedDividends, - const SmallVectorImpl &expectedDenominators) { - std::vector> dividends; - SmallVector denominators; - - poly.getLocalReprs(dividends, denominators); + ArrayRef expectedDenominators) { + DivisionRepr divs = poly.getLocalReprs(); // Check that the `denominators` and `expectedDenominators` match. - EXPECT_TRUE(expectedDenominators == denominators); + EXPECT_TRUE(expectedDenominators == divs.getDenoms()); // Check that the `dividends` and `expectedDividends` match. If the // denominator for a division is zero, we ignore its dividend.
- EXPECT_TRUE(dividends.size() == expectedDividends.size()); - for (unsigned i = 0, e = dividends.size(); i < e; ++i) { - if (denominators[i] != 0) { - EXPECT_TRUE(expectedDividends[i] == dividends[i]); - } - } + ASSERT_EQ(divs.getNumDivs(), expectedDividends.size()); + for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) + if (divs.hasDivRepr(i)) + for (unsigned j = 0, f = divs.getNumVars() + 1; j < f; ++j) + EXPECT_TRUE(expectedDividends[i][j] == divs.getDividend(i)[j]); } TEST(IntegerPolyhedronTest, computeLocalReprSimple) {