diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -128,6 +128,8 @@ /// Add `scale` multiples of the source row to the target row. void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale); + /// Add `scale` multiples of the rowVec row to the specified row. + void addToRow(unsigned row, ArrayRef rowVec, int64_t scale); /// Add `scale` multiples of the source column to the target column. void addToColumn(unsigned sourceColumn, unsigned targetColumn, int64_t scale); diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -22,94 +22,93 @@ namespace mlir { namespace presburger { -/// This class represents a multi-affine function whose domain is given by an -/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a -/// tuple of integer values attached to every point in the polyhedron, with the -/// value of each element of the tuple given by an affine expression in the vars -/// of the polyhedron. For example we could have the domain -/// -/// (x, y) : (x >= 5, y >= x) -/// -/// and a tuple of three integers defined at every point in the polyhedron: +/// This class represents a multi-affine function with the domain as Z^d, where +/// `d` is the number of domain variables of the function. For example: /// /// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y). /// -/// In this way every point in the polyhedron has a tuple of integers associated -/// with it. If the integer polyhedron has local vars, then the output -/// expressions can use them as well. 
The output expressions are represented as -/// a matrix with one row for every element in the output vector one column for -/// each var, and an extra column at the end for the constant term. +/// The output expressions are represented as a matrix with one row for every +/// output, one column for each var including division variables, and an extra +/// column at the end for the constant term. /// /// Checking equality of two such functions is supported, as well as finding the /// value of the function at a specified point. class MultiAffineFunction { public: - MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output) - : domainSet(domain), output(output) {} - MultiAffineFunction(const Matrix &output, const PresburgerSpace &space) - : domainSet(space), output(output) {} - - unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolVars(); } - unsigned getNumOutputs() const { return output.getNumRows(); } - bool isConsistent() const { - return output.getNumColumns() == domainSet.getNumVars() + 1; + MultiAffineFunction(const PresburgerSpace &space, const Matrix &output) + : space(space), output(output), + divs(space.getNumVars() - space.getNumRangeVars()) { + assertIsConsistent(); + } + + MultiAffineFunction(const PresburgerSpace &space, const Matrix &output, + const DivisionRepr &divs) + : space(space), output(output), divs(divs) { + assertIsConsistent(); } - /// Get the space of the input domain of this function. - const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); } - /// Get the input domain of this function. - const IntegerPolyhedron &getDomain() const { return domainSet; } + unsigned getNumDomainVars() const { return space.getNumDomainVars(); } + unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); } + unsigned getNumOutputs() const { return space.getNumRangeVars(); } + unsigned getNumDivs() const { return space.getNumLocalVars(); } + + /// Get the space of this function. 
+ const PresburgerSpace &getSpace() const { return space; } + /// Get the domain/output space of the function. The returned space is a set + /// space. + PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); } + PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); } + /// Get a matrix with each row representing row^th output expression. const Matrix &getOutputMatrix() const { return output; } /// Get the `i^th` output expression. ArrayRef getOutputExpr(unsigned i) const { return output.getRow(i); } - /// Insert `num` variables of the specified kind at position `pos`. - /// Positions are relative to the kind of variable. The coefficient columns - /// corresponding to the added variables are initialized to zero. Return the - /// absolute column position (i.e., not relative to the kind of variable) - /// of the first added variable. - unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1); - - /// Remove the specified range of vars. - void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit); - - /// Given a MAF `other`, merges local variables such that both funcitons - /// have union of local vars, without changing the set of points in domain or - /// the output. - void mergeLocalVars(MultiAffineFunction &other); - - /// Return whether the outputs of `this` and `other` agree wherever both - /// functions are defined, i.e., the outputs should be equal for all points in - /// the intersection of the domains. - bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const; - - /// Return whether the `this` and `other` are equal. This is the case if - /// they lie in the same space, i.e. have the same dimensions, and their - /// domains are identical and their outputs are equal on their domain. + // Remove the specified range of outputs. 
+ void removeOutputs(unsigned start, unsigned end); + + /// Given a MAF `other`, merges division variables such that both functions + /// have the union of the division vars that exist in the functions. + void mergeDivs(MultiAffineFunction &other); + + /// Return the output of the function at the given point. + SmallVector valueAt(ArrayRef point) const; + + /// Return whether `this` and `other` are equal when the domain is + /// restricted to `domain`. This is the case if they lie in the same space, + /// and their outputs are equal for every point in `domain`. bool isEqual(const MultiAffineFunction &other) const; + bool isEqual(const MultiAffineFunction &other, + const IntegerPolyhedron &domain) const; + bool isEqual(const MultiAffineFunction &other, + const PresburgerSet &domain) const; - /// Get the value of the function at the specified point. If the point lies - /// outside the domain, an empty optional is returned. - Optional> valueAt(ArrayRef point) const; + void subtract(const MultiAffineFunction &other); - /// Truncate the output dimensions to the first `count` dimensions. - /// - /// TODO: refactor so that this can be accomplished through removeVarRange. - void truncateOutput(unsigned count); + /// Get this function as a relation. + IntegerRelation getAsRelation() const; void print(raw_ostream &os) const; void dump() const; private: - /// The IntegerPolyhedron representing the domain over which the function is - /// defined. - IntegerPolyhedron domainSet; + /// Assert that the MAF is consistent. + void assertIsConsistent() const; + + /// The space of this function. The domain variables are considered as the + /// input variables of the function. The range variables are considered as + /// the outputs. The symbols parametrize the function and locals are used to + /// represent divisions. Each local variable has a corresponding division + /// representation stored in `divs`. 
+ PresburgerSpace space; /// The function's output is a tuple of integers, with the ith element of the /// tuple defined by the affine expression given by the ith row of this output /// matrix. Matrix output; + + /// Storage for division representation for each local variable in space. + DivisionRepr divs; }; /// This class represents a piece-wise MultiAffineFunction. This can be thought @@ -132,33 +131,47 @@ /// finding the value of the function at a point. class PWMAFunction { public: - PWMAFunction(const PresburgerSpace &space, unsigned numOutputs) - : space(space), numOutputs(numOutputs) { - assert(space.getNumDomainVars() == 0 && - "Set type space should have zero domain vars."); + struct Piece { + PresburgerSet domain; + MultiAffineFunction output; + + bool isConsistent() const { + return domain.getSpace().isCompatible(output.getDomainSpace()); + } + }; + + PWMAFunction(const PresburgerSpace &space) : space(space) { assert(space.getNumLocalVars() == 0 && "PWMAFunction cannot have local vars."); - assert(numOutputs >= 1 && "The function must output something!"); } + // Get the space of this function. const PresburgerSpace &getSpace() const { return space; } - void addPiece(const MultiAffineFunction &piece); - void addPiece(const IntegerPolyhedron &domain, const Matrix &output); - void addPiece(const PresburgerSet &domain, const Matrix &output); + // Add a piece ([domain, output] pair) to this function. 
+ void addPiece(const Piece &piece); - const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; } unsigned getNumPieces() const { return pieces.size(); } - unsigned getNumOutputs() const { return numOutputs; } - unsigned getNumInputs() const { return space.getNumVars(); } - MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; } + unsigned getNumVarKind(VarKind kind) const { + return space.getNumVarKind(kind); + } + unsigned getNumDomainVars() const { return space.getNumDomainVars(); } + unsigned getNumOutputs() const { return space.getNumRangeVars(); } + unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); } + + /// Remove the specified range of outputs. + void removeOutputs(unsigned start, unsigned end); + + /// Get the domain/output space of the function. The returned space is a set + /// space. + PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); } + PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); } /// Return the domain of this piece-wise MultiAffineFunction. This is the /// union of the domains of all the pieces. PresburgerSet getDomain() const; - /// Return the value at the specified point and an empty optional if the - /// point does not lie in the domain. + /// Return the output of the function at the given point. Optional> valueAt(ArrayRef point) const; /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether @@ -166,11 +179,6 @@ /// value at every point in the domain. bool isEqual(const PWMAFunction &other) const; - /// Truncate the output dimensions to the first `count` dimensions. - /// - /// TODO: refactor so that this can be accomplished through removeVarRange. 
- void truncateOutput(unsigned count); - /// Return a function defined on the union of the domains of this and func, /// such that when only one of the functions is defined, it outputs the same /// as that function, and if both are defined, it outputs the lexmax/lexmin of @@ -178,8 +186,8 @@ /// function is not defined either. /// /// Currently this does not support PWMAFunctions which have pieces containing - /// local variables. - /// TODO: Support local variables in peices. + /// divisions. + /// TODO: Support division in pieces. PWMAFunction unionLexMin(const PWMAFunction &func); PWMAFunction unionLexMax(const PWMAFunction &func); @@ -200,19 +208,17 @@ /// /// The PresburgerSet returned by `tiebreak` should be disjoint. /// TODO: Remove this constraint of returning disjoint set. - PWMAFunction - unionFunction(const PWMAFunction &func, - llvm::function_ref - tiebreak) const; + PWMAFunction unionFunction( + const PWMAFunction &func, + llvm::function_ref tiebreak) const; + /// The space of this function. The domain variables are considered as the + /// input variables of the function. The range variables are considered as + /// the outputs. The symbols parameterize the function. PresburgerSpace space; - /// The list of pieces in this piece-wise MultiAffineFunction. - SmallVector pieces; - - /// The number of output vars. - unsigned numOutputs; + // The pieces of the PWMAFunction. + SmallVector pieces; }; } // namespace presburger diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -90,6 +90,11 @@ numLocals); } + // Get the domain/range space of this space. The returned space is a set + // space. 
+ PresburgerSpace getDomainSpace() const; + PresburgerSpace getRangeSpace() const; + unsigned getNumDomainVars() const { return numDomain; } unsigned getNumRangeVars() const { return numRange; } unsigned getNumSetDimVars() const { return numRange; } diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -529,9 +529,9 @@ /// Represents the result of a symbolic lexicographic minimization computation. struct SymbolicLexMin { - SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs) - : lexmin(domainSpace, numOutputs), - unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {} + SymbolicLexMin(const PresburgerSpace &space) + : lexmin(space), + unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {} /// This maps assignments of symbols to the corresponding lexmin. /// Takes no value when no integer sample exists for the assignment or if the diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -118,7 +118,7 @@ DivisionRepr(unsigned numVars, unsigned numDivs) : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {} - DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {} + DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {} unsigned getNumVars() const { return dividends.getNumColumns() - 1; } unsigned getNumDivs() const { return dividends.getNumRows(); } @@ -142,16 +142,25 @@ return dividends.getRow(i); } + // For a given point containing values for each variable other than the + // division variables, try to find the values for each division variable from + // their division representation. + SmallVector, 4> divValuesAt(ArrayRef point) const; + // Get the `i^th` denominator. 
unsigned &getDenom(unsigned i) { return denoms[i]; } unsigned getDenom(unsigned i) const { return denoms[i]; } ArrayRef getDenoms() const { return denoms; } - void setDividend(unsigned i, ArrayRef dividend) { + void setDiv(unsigned i, ArrayRef dividend, unsigned divisor) { dividends.setRow(i, dividend); + denoms[i] = divisor; } + void insertDiv(unsigned pos, ArrayRef dividend, unsigned divisor); + void insertDiv(unsigned pos, unsigned num = 1); + /// Removes duplicate divisions. On every possible duplicate division found, /// `merge(i, j)`, where `i`, `j` are current index of the duplicate /// divisions, is called and division at index `j` is merged into division at diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -238,6 +238,7 @@ getVarKindEnd(VarKind::Domain)); // Compute the symbolic lexmin of the dims and locals, with the symbols being // the actual symbols of this set. + // The resultant space of lexmin is the space of the relation itself. SymbolicLexMin result = SymbolicLexSimplex(*this, IntegerPolyhedron(PresburgerSpace::getSetSpace( @@ -248,8 +249,8 @@ // We want to return only the lexmin over the dims, so strip the locals from // the computed lexmin. 
- result.lexmin.truncateOutput(result.lexmin.getNumOutputs() - - getNumLocalVars()); + result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(), + result.lexmin.getNumOutputs()); return result; } diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -192,10 +192,14 @@ } void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) { + addToRow(targetRow, getRow(sourceRow), scale); +} + +void Matrix::addToRow(unsigned row, ArrayRef rowVec, int64_t scale) { if (scale == 0) return; for (unsigned col = 0; col < nColumns; ++col) - at(targetRow, col) += scale * at(sourceRow, col); + at(row, col) += scale * rowVec[col]; } void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn, diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -12,11 +12,25 @@ using namespace mlir; using namespace presburger; +void MultiAffineFunction::assertIsConsistent() const { + assert(space.getNumVars() - space.getNumRangeVars() + 1 == + output.getNumColumns() && + "Inconsistent number of output columns"); + assert(space.getNumDomainVars() + space.getNumSymbolVars() == + divs.getNumNonDivs() && + "Inconsistent number of non-division variables in divs"); + assert(space.getNumRangeVars() == output.getNumRows() && + "Inconsistent number of output rows"); + assert(space.getNumLocalVars() == divs.getNumDivs() && + "Inconsistent number of divisions."); + assert(divs.hasAllReprs() && "All divisions should have a representation"); +} + // Return the result of subtracting the two given vectors pointwise. // The vectors must be of the same size. // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. 
-static SmallVector subtract(ArrayRef vecA, - ArrayRef vecB) { +static SmallVector subtractExprs(ArrayRef vecA, + ArrayRef vecB) { assert(vecA.size() == vecB.size() && "Cannot subtract vectors of differing lengths!"); SmallVector result; @@ -27,152 +41,135 @@ } PresburgerSet PWMAFunction::getDomain() const { - PresburgerSet domain = PresburgerSet::getEmpty(getSpace()); - for (const MultiAffineFunction &piece : pieces) - domain.unionInPlace(piece.getDomain()); + PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace()); + for (const Piece &piece : pieces) + domain.unionInPlace(piece.domain); return domain; } -Optional> +void MultiAffineFunction::print(raw_ostream &os) const { + space.print(os); + os << "Division Representation:\n"; + divs.print(os); + os << "Output:\n"; + output.print(os); +} + +SmallVector MultiAffineFunction::valueAt(ArrayRef point) const { - assert(point.size() == domainSet.getNumDimAndSymbolVars() && + assert(point.size() == getNumDomainVars() + getNumSymbolVars() && "Point has incorrect dimensionality!"); - Optional> maybeLocalValues = - getDomain().containsPointNoLocal(point); - if (!maybeLocalValues) - return {}; - - // The point lies in the domain, so we need to compute the output value. SmallVector pointHomogenous{llvm::to_vector(point)}; - // The given point didn't include the values of locals which the output is a - // function of; we have computed one possible set of values and use them - // here. The function is not allowed to have local vars that take more than - // one possible value. - pointHomogenous.append(*maybeLocalValues); + // Get the division values at this point. + SmallVector, 8> divValues = divs.divValuesAt(point); + // The given point didn't include the values of the divs which the output is a + // function of; we have computed one possible set of values and use them here. 
+ pointHomogenous.reserve(pointHomogenous.size() + divValues.size()); + for (const Optional &divVal : divValues) + pointHomogenous.push_back(*divVal); // The matrix `output` has an affine expression in the ith row, corresponding // to the expression for the ith value in the output vector. The last column // of the matrix contains the constant term. Let v be the input point with // a 1 appended at the end. We can see that output * v gives the desired // output vector. - pointHomogenous.emplace_back(1); + pointHomogenous.push_back(1); SmallVector result = output.postMultiplyWithColumn(pointHomogenous); assert(result.size() == getNumOutputs()); return result; } -Optional> -PWMAFunction::valueAt(ArrayRef point) const { - assert(point.size() == getNumInputs() && - "Point has incorrect dimensionality!"); - for (const MultiAffineFunction &piece : pieces) - if (Optional> output = piece.valueAt(point)) - return output; - return {}; +bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { + assert(space.isCompatible(other.space) && + "Spaces should be compatible for equality check."); + return getAsRelation().isEqual(other.getAsRelation()); } -void MultiAffineFunction::print(raw_ostream &os) const { - os << "Domain:"; - domainSet.print(os); - os << "Output:\n"; - output.print(os); - os << "\n"; -} +bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, + const IntegerPolyhedron &domain) const { + assert(space.isCompatible(other.space) && + "Spaces should be compatible for equality check."); + IntegerRelation restrictedThis = getAsRelation(); + restrictedThis.intersectDomain(domain); -void MultiAffineFunction::dump() const { print(llvm::errs()); } + IntegerRelation restrictedOther = other.getAsRelation(); + restrictedOther.intersectDomain(domain); -bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { - return getDomainSpace().isCompatible(other.getDomainSpace()) && - getDomain().isEqual(other.getDomain()) && - 
isEqualWhereDomainsOverlap(other); + return restrictedThis.isEqual(restrictedOther); } -unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos, - unsigned num) { - assert(kind != VarKind::Domain && "Domain has to be zero in a set"); - unsigned absolutePos = domainSet.getVarKindOffset(kind) + pos; - output.insertColumns(absolutePos, num); - return domainSet.insertVar(kind, pos, num); +bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, + const PresburgerSet &domain) const { + assert(space.isCompatible(other.space) && + "Spaces should be compatible for equality check."); + return llvm::all_of(domain.getAllDisjuncts(), + [&](const IntegerRelation &disjunct) { + return isEqual(other, IntegerPolyhedron(disjunct)); + }); } -void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart, - unsigned varLimit) { - output.removeColumns(varStart + domainSet.getVarKindOffset(kind), - varLimit - varStart); - domainSet.removeVarRange(kind, varStart, varLimit); -} +void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) { + assert(end <= getNumOutputs() && "Invalid range"); -void MultiAffineFunction::truncateOutput(unsigned count) { - assert(count <= output.getNumRows()); - output.resizeVertically(count); -} + if (start >= end) + return; -void PWMAFunction::truncateOutput(unsigned count) { - assert(count <= numOutputs); - for (MultiAffineFunction &piece : pieces) - piece.truncateOutput(count); - numOutputs = count; + space.removeVarRange(VarKind::Range, start, end); + output.removeRows(start, end - start); } -void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) { - // Merge output local vars of both functions without using division - // information i.e. append local vars of `other` to `this` and insert - // local vars of `this` to `other` at the start of it's local vars. 
- output.insertColumns(domainSet.getVarKindEnd(VarKind::Local), - other.domainSet.getNumLocalVars()); - other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local), - domainSet.getNumLocalVars()); +void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { + assert(space.isCompatible(other.space) && "Functions should be compatible"); - auto merge = [this, &other](unsigned i, unsigned j) -> bool { - // Merge local at position j into local at position i in function domain. - domainSet.eliminateRedundantLocalVar(i, j); - other.domainSet.eliminateRedundantLocalVar(i, j); + unsigned nDivs = getNumDivs(); + unsigned divOffset = divs.getDivOffset(); - unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local); + other.divs.insertDiv(0, nDivs); - // Merge local at position j into local at position i in output domain. - output.addToColumn(localOffset + j, localOffset + i, 1); - output.removeColumn(localOffset + j); - other.output.addToColumn(localOffset + j, localOffset + i, 1); - other.output.removeColumn(localOffset + j); + SmallVector div(other.divs.getNumVars() + 1); + for (unsigned i = 0; i < nDivs; ++i) { + // Zero fill. + std::fill(div.begin(), div.end(), 0); + // Fill div with dividend from `divs`. Do not fill the constant. + std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1, + div.begin()); + // Fill constant. + div.back() = divs.getDividend(i).back(); + other.divs.setDiv(i, div, divs.getDenom(i)); + } + + other.space.insertVar(VarKind::Local, 0, nDivs); + other.output.insertColumns(divOffset, nDivs); + + auto merge = [&](unsigned i, unsigned j) { + // We only merge from local at pos j to local at pos i, where j > i. + if (i >= j) + return false; + // If j < nDivs, we are trying to merge duplicate divs in `this`. Since we + // do not want to merge duplicates in `this`, we ignore this call. + if (j < nDivs) + return false; + + // Merge things in space and output. 
+ other.space.removeVarRange(VarKind::Local, j, j + 1); + other.output.addToColumn(divOffset + j, divOffset + i, 1); + other.output.removeColumn(divOffset + j); return true; }; - presburger::mergeLocalVars(domainSet, other.domainSet, merge); -} + other.divs.removeDuplicateDivs(merge); -bool MultiAffineFunction::isEqualWhereDomainsOverlap( - MultiAffineFunction other) const { - if (!getDomainSpace().isCompatible(other.getDomainSpace())) - return false; + unsigned newDivs = other.divs.getNumDivs() - nDivs; - // `commonFunc` has the same output as `this`. - MultiAffineFunction commonFunc = *this; - // After this merge, `commonFunc` and `other` have the same local vars; they - // are merged. - commonFunc.mergeLocalVars(other); - // After this, the domain of `commonFunc` will be the intersection of the - // domains of `this` and `other`. - commonFunc.domainSet.append(other.domainSet); - - // `commonDomainMatching` contains the subset of the common domain - // where the outputs of `this` and `other` match. - // - // We want to add constraints equating the outputs of `this` and `other`. - // However, `this` may have difference local vars from `other`, whereas we - // need both to have the same locals. Accordingly, we use `commonFunc.output` - // in place of `this->output`, since `commonFunc` has the same output but also - // has its locals merged. - IntegerPolyhedron commonDomainMatching = commonFunc.getDomain(); - for (unsigned row = 0, e = getNumOutputs(); row < e; ++row) - commonDomainMatching.addEquality( - subtract(commonFunc.output.getRow(row), other.output.getRow(row))); - - // If the whole common domain is a subset of commonDomainMatching, then they - // are equal and the two functions match on the whole common domain. - return commonFunc.getDomain().isSubsetOf(commonDomainMatching); + space.insertVar(VarKind::Local, nDivs, newDivs); + output.insertColumns(divOffset + nDivs, newDivs); + divs = other.divs; + + // Check consistency. 
+ assertIsConsistent(); + other.assertIsConsistent(); } /// Two PWMAFunctions are equal if they have the same dimensionalities, @@ -188,89 +185,79 @@ // overlap, they take the same output value. If `this` and `other` have the // same domain (checked above), then this check passes iff the two functions // have the same output at every point in the domain. - for (const MultiAffineFunction &aPiece : this->pieces) - for (const MultiAffineFunction &bPiece : other.pieces) - if (!aPiece.isEqualWhereDomainsOverlap(bPiece)) - return false; - return true; + return llvm::all_of(this->pieces, [&other](const Piece &pieceA) { + return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) { + PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain); + return pieceA.output.isEqual(pieceB.output, commonDomain); + }); + }); } -void PWMAFunction::addPiece(const MultiAffineFunction &piece) { - assert(space.isCompatible(piece.getDomainSpace()) && - "Piece to be added is not compatible with this PWMAFunction!"); - assert(piece.isConsistent() && "Piece is internally inconsistent!"); - assert(this->getDomain() - .intersect(PresburgerSet(piece.getDomain())) - .isIntegerEmpty() && - "New piece's domain overlaps with that of existing pieces!"); +void PWMAFunction::addPiece(const Piece &piece) { + assert(piece.isConsistent() && "Piece should be consistent"); pieces.push_back(piece); } -void PWMAFunction::addPiece(const IntegerPolyhedron &domain, - const Matrix &output) { - addPiece(MultiAffineFunction(domain, output)); -} - -void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) { - for (const IntegerRelation &newDom : domain.getAllDisjuncts()) - addPiece(IntegerPolyhedron(newDom), output); -} - void PWMAFunction::print(raw_ostream &os) const { - os << pieces.size() << " pieces:\n"; - for (const MultiAffineFunction &piece : pieces) - piece.print(os); + space.print(os); + os << getNumPieces() << " pieces:\n"; + for (const Piece &piece : pieces) { + os << 
"Domain of piece:\n"; + piece.domain.print(os); + os << "Output of piece\n"; + piece.output.print(os); + } } void PWMAFunction::dump() const { print(llvm::errs()); } PWMAFunction PWMAFunction::unionFunction( const PWMAFunction &func, - llvm::function_ref - tiebreak) const { + llvm::function_ref tiebreak) const { assert(getNumOutputs() == func.getNumOutputs() && - "Number of outputs of functions should be same."); + "Ranges of functions should be same."); assert(getSpace().isCompatible(func.getSpace()) && "Space is not compatible."); // The algorithm used here is as follows: - // - Add the output of funcB for the part of the domain where both funcA and - // funcB are defined, and `tiebreak` chooses the output of funcB. - // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses - // funcA over funcB. - // - Add the output of funcB, where funcA is not defined. - - // Add parts of the common domain where funcB's output is used. Also - // add all the parts where funcA's output is used, both common and non-common. - PWMAFunction result(getSpace(), getNumOutputs()); - for (const MultiAffineFunction &funcA : pieces) { - PresburgerSet dom(funcA.getDomain()); - for (const MultiAffineFunction &funcB : func.pieces) { - PresburgerSet better = tiebreak(funcB, funcA); - // Add the output of funcB, where it is better than output of funcA. + // - Add the output of pieceB for the part of the domain where both pieceA and + // pieceB are defined, and `tiebreak` chooses the output of pieceB. + // - Add the output of pieceA, where pieceB is not defined or `tiebreak` + // chooses + // pieceA over pieceB. + // - Add the output of pieceB, where pieceA is not defined. + + // Add parts of the common domain where pieceB's output is used. Also + // add all the parts where pieceA's output is used, both common and + // non-common. 
+ PWMAFunction result(getSpace()); + for (const Piece &pieceA : pieces) { + PresburgerSet dom(pieceA.domain); + for (const Piece &pieceB : func.pieces) { + PresburgerSet better = tiebreak(pieceB, pieceA); + // Add the output of pieceB, where it is better than output of pieceA. // The disjuncts in "better" will be disjoint as tiebreak should gurantee // that. - result.addPiece(better, funcB.getOutputMatrix()); + result.addPiece({better, pieceB.output}); dom = dom.subtract(better); } - // Add output of funcA, where it is better than funcB, or funcB is not + // Add output of pieceA, where it is better than pieceB, or pieceB is not // defined. // // `dom` here is guranteed to be disjoint from already added pieces // because because the pieces added before are either: // - Subsets of the domain of other MAFs in `this`, which are guranteed // to be disjoint from `dom`, or - // - They are one of the pieces added for `funcB`, and we have been + // - They are one of the pieces added for `pieceB`, and we have been // subtracting all such pieces from `dom`, so `dom` is disjoint from those // pieces as well. - result.addPiece(dom, funcA.getOutputMatrix()); + result.addPiece({dom, pieceA.output}); } - // Add parts of funcB which are not shared with funcA. + // Add parts of pieceB which are not shared with pieceA. PresburgerSet dom = getDomain(); - for (const MultiAffineFunction &funcB : func.pieces) - result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix()); + for (const Piece &pieceB : func.pieces) + result.addPiece({pieceB.domain.subtract(dom), pieceB.output}); return result; } @@ -280,21 +267,19 @@ /// taking the lexicographically smaller output and otherwise, by taking the /// lexicographically larger output. 
template -static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA, - const MultiAffineFunction &mafB) { +static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, + const PWMAFunction::Piece &pieceB) { // TODO: Support local variables here. - assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) && - "Domain spaces should be compatible."); - assert(mafA.getNumOutputs() == mafB.getNumOutputs() && - "Number of outputs of both functions should be same."); - assert(mafA.getDomain().getNumLocalVars() == 0 && + assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) && + "Pieces should be compatible"); + assert(pieceA.domain.getSpace().getNumLocalVars() == 0 && "Local variables are not supported yet."); - PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals(); - const PresburgerSpace &space = mafA.getDomain().getSpace(); + PresburgerSpace compatibleSpace = pieceA.domain.getSpace(); + const PresburgerSpace &space = pieceA.domain.getSpace(); // We first create the set `result`, corresponding to the set where output - // of mafA is lexicographically larger/smaller than mafB. This is done by + // of pieceA is lexicographically larger/smaller than pieceB. This is done by // creating a PresburgerSet with the following constraints: // // (outA[0] > outB[0]) U @@ -312,14 +297,15 @@ // ... 
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace); - IntegerPolyhedron levelSet(/*numReservedInequalities=*/1, - /*numReservedEqualities=*/mafA.getNumOutputs(), - /*numReservedCols=*/space.getNumVars() + 1, space); - for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) { + IntegerPolyhedron levelSet( + /*numReservedInequalities=*/1, + /*numReservedEqualities=*/pieceA.output.getNumOutputs(), + /*numReservedCols=*/space.getNumVars() + 1, space); + for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) { // Create the expression `outA - outB` for this level. - SmallVector subExpr = - subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level)); + SmallVector subExpr = subtractExprs( + pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level)); if (lexMin) { // For lexMin, we add an upper bound of -1: @@ -343,10 +329,9 @@ levelSet.addEquality(subExpr); } - // We then intersect `result` with the domain of mafA and mafB, to only + // We then intersect `result` with the domain of pieceA and pieceB, to only // tiebreak on the domain where both are defined. - result = result.intersect(PresburgerSet(mafA.getDomain())) - .intersect(PresburgerSet(mafB.getDomain())); + result = result.intersect(pieceA.domain).intersect(pieceB.domain); return result; } @@ -358,3 +343,93 @@ PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { return unionFunction(func, tiebreakLex); } + +void MultiAffineFunction::subtract(const MultiAffineFunction &other) { + assert(space.isCompatible(other.space) && + "Spaces should be compatible for subtraction."); + + MultiAffineFunction copyOther = other; + mergeDivs(copyOther); + for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) + output.addToRow(i, copyOther.getOutputExpr(i), -1); + + // Check consistency. 
+ assertIsConsistent(); +} + +/// Adds division constraints corresponding to local variables, given a +/// relation and division representations of the local variables in the +/// relation. +static void addDivisionConstraints(IntegerRelation &rel, + const DivisionRepr &divs) { + assert(divs.hasAllReprs() && + "All divisions in divs should have a representation"); + assert(rel.getNumVars() == divs.getNumVars() && + "Relation and divs should have the same number of vars"); + assert(rel.getNumLocalVars() == divs.getNumDivs() && + "Relation and divs should have the same number of local vars"); + + for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) { + rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i), + divs.getDivOffset() + i)); + rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i), + divs.getDivOffset() + i)); + } +} + +IntegerRelation MultiAffineFunction::getAsRelation() const { + // Create a relation corresponding to the input space plus the divisions + // used in outputs. + IntegerRelation result(PresburgerSpace::getRelationSpace( + space.getNumDomainVars(), 0, space.getNumSymbolVars(), + space.getNumLocalVars())); + // Add division constraints corresponding to divisions used in outputs. + addDivisionConstraints(result, divs); + // The outputs are represented as range variables in the relation. We add + // range variables for the outputs. + result.insertVar(VarKind::Range, 0, getNumOutputs()); + + // Add equalities such that the i^th range variable is equal to the i^th + // output expression. + SmallVector eq(result.getNumCols()); + for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { + // TODO: Add functions to get VarKind offsets in output in MAF and use them + // here. + // The output expression does not contain range variables, while the + // equality does. So, we need to copy all variables and mark all range + // variables as 0 in the equality. 
+ ArrayRef expr = getOutputExpr(i); + // Copy domain variables in `expr` to domain variables in `eq`. + std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin()); + // Fill the range variables in `eq` as zero. + std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range), + eq.begin() + result.getVarKindEnd(VarKind::Range), 0); + // Copy remaining variables in `expr` to the remaining variables in `eq`. + std::copy(expr.begin() + getNumDomainVars(), expr.end(), + eq.begin() + result.getVarKindEnd(VarKind::Range)); + + // Set the i^th range var to -1 in `eq` to equate the output expression to + // this range var. + eq[result.getVarKindOffset(VarKind::Range) + i] = -1; + // Add the equality `rangeVar_i = output[i]`. + result.addEquality(eq); + } + + return result; +} + +void PWMAFunction::removeOutputs(unsigned start, unsigned end) { + space.removeVarRange(VarKind::Range, start, end); + for (Piece &piece : pieces) + piece.output.removeOutputs(start, end); +} + +Optional> +PWMAFunction::valueAt(ArrayRef point) const { + assert(point.size() == getNumDomainVars() + getNumSymbolVars()); + + for (const Piece &piece : pieces) + if (piece.domain.containsPoint(point)) + return piece.output.valueAt(point); + return None; +} diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -13,6 +13,15 @@ using namespace mlir; using namespace presburger; +PresburgerSpace PresburgerSpace::getDomainSpace() const { + // TODO: Preserve identifiers here. 
+ return PresburgerSpace::getSetSpace(numDomain, numSymbols, numLocals); +} + +PresburgerSpace PresburgerSpace::getRangeSpace() const { + return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals); +} + unsigned PresburgerSpace::getNumVarKind(VarKind kind) const { if (kind == VarKind::Domain) return getNumDomainVars(); diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -466,7 +466,14 @@ } output.appendExtraRow(sample); } - result.lexmin.addPiece(domainPoly, output); + + // Store the output in a MultiAffineFunction and add it to the result. + PresburgerSpace funcSpace = result.lexmin.getSpace(); + funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars()); + + result.lexmin.addPiece( + {PresburgerSet(domainPoly), + MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())}); } Optional SymbolicLexSimplex::maybeGetAlwaysViolatedRow() { @@ -508,7 +515,10 @@ } SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { - SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol); + SymbolicLexMin result(PresburgerSpace::getRelationSpace( + /*numDomain=*/domainPoly.getNumDimVars(), + /*numRange=*/var.size() - nSymbol, + /*numSymbols=*/domainPoly.getNumSymbolVars())); /// The algorithm is more naturally expressed recursively, but we implement /// it iteratively here to avoid potential issues with stack overflows in the diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -16,6 +16,8 @@ #include "mlir/Support/MathExtras.h" #include +#include + using namespace mlir; using namespace presburger; @@ -280,10 +282,8 @@ DivisionRepr divsA = relA.getLocalReprs(); DivisionRepr divsB = relB.getLocalReprs(); - for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; 
++i) { - divsA.setDividend(i, divsB.getDividend(i)); - divsA.getDenom(i) = divsB.getDenom(i); - } + for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) + divsA.setDiv(i, divsB.getDividend(i), divsB.getDenom(i)); // Remove duplicate divisions from divsA. The removing duplicate divisions // call, calls `merge` to effectively merge divisions in relA and relB. @@ -357,6 +357,55 @@ return coeffs; } +SmallVector, 4> +DivisionRepr::divValuesAt(ArrayRef point) const { + assert(point.size() == getNumNonDivs() && "Incorrect point size"); + + SmallVector, 4> divValues(getNumDivs(), None); + bool changed = true; + while (changed) { + changed = false; + for (unsigned i = 0, e = getNumDivs(); i < e; ++i) { + // If division value is found, continue; + if (divValues[i]) + continue; + + ArrayRef dividend = getDividend(i); + int64_t divVal = 0; + + // Check if we have all the division values required for this division. + unsigned j, f; + for (j = 0, f = getNumDivs(); j < f; ++j) { + if (dividend[getDivOffset() + j] == 0) + continue; + // Division value required, but not found yet. + if (!divValues[j]) + break; + divVal += dividend[getDivOffset() + j] * divValues[j].value(); + } + + // We have some division values that are still not found, but are required + // to find the value of this division. + if (j < f) + continue; + + // Fill remaining values. + divVal = std::inner_product(point.begin(), point.end(), dividend.begin(), + divVal); + // Add constant. + divVal += dividend.back(); + // Take floor division with denominator. + divVal = floorDiv(divVal, denoms[i]); + + // Set div value and continue. 
+ divValues[i] = divVal; + changed = true; + } + } + + return divValues; +} + void DivisionRepr::removeDuplicateDivs( llvm::function_ref merge) { @@ -402,6 +451,23 @@ } } +void DivisionRepr::insertDiv(unsigned pos, ArrayRef dividend, + unsigned divisor) { + assert(pos <= getNumDivs() && "Invalid insertion position"); + assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size"); + + dividends.appendExtraRow(dividend); + denoms.insert(denoms.begin() + pos, divisor); + dividends.insertColumn(getDivOffset() + pos); +} + +void DivisionRepr::insertDiv(unsigned pos, unsigned num) { + assert(pos <= getNumDivs() && "Invalid insertion position"); + dividends.insertColumns(getDivOffset() + pos, num); + dividends.insertRows(pos, num); + denoms.insert(denoms.begin() + pos, num, 0); +} + void DivisionRepr::print(raw_ostream &os) const { os << "Dividends:\n"; dividends.print(os); diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -1171,7 +1171,7 @@ ASSERT_NE(poly.getNumSymbolVars(), 0u); PWMAFunction expectedLexmin = - parsePWMAF(/*numInputs=*/poly.getNumSymbolVars(), + parsePWMAF(/*numInputs=*/0, /*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr, /*numSymbols=*/poly.getNumSymbolVars()); diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -130,7 +130,7 @@ .findSymbolicIntegerLexMin(); PWMAFunction expectedLexmin = - parsePWMAF(/*numInputs=*/2, + parsePWMAF(/*numInputs=*/1, /*numOutputs=*/1, { {"(a)[b] : (a - b >= 0)", {{1, 0, 0}}}, // a diff --git a/mlir/unittests/Analysis/Presburger/Utils.h 
b/mlir/unittests/Analysis/Presburger/Utils.h --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -73,14 +73,20 @@ unsigned numSymbols = 0) { static MLIRContext context; - PWMAFunction result(PresburgerSpace::getSetSpace( - /*numDims=*/numInputs - numSymbols, numSymbols), - numOutputs); + PWMAFunction result( + PresburgerSpace::getRelationSpace(numInputs, numOutputs, numSymbols)); for (const auto &pair : data) { IntegerPolyhedron domain = parsePoly(pair.first); + PresburgerSpace funcSpace = result.getSpace(); + funcSpace.insertVar(VarKind::Local, 0, domain.getNumLocalVars()); + result.addPiece( - domain, makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second)); + {PresburgerSet(domain), + MultiAffineFunction( + funcSpace, + makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second), + domain.getLocalReprs())}); } return result; }