diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -49,7 +49,6 @@ enum class Kind { FlatAffineConstraints, FlatAffineValueConstraints, - MultiAffineFunction, IntegerRelation, IntegerPolyhedron, }; diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -42,54 +42,35 @@ /// /// Checking equality of two such functions is supported, as well as finding the /// value of the function at a specified point. -class MultiAffineFunction : protected IntegerPolyhedron { +class MultiAffineFunction { public: - /// We use protected inheritance to avoid inheriting the whole public - /// interface of IntegerPolyhedron. These using declarations explicitly make - /// only the relevant functions part of the public interface.
- using IntegerPolyhedron::getNumDimAndSymbolIds; - using IntegerPolyhedron::getNumDimIds; - using IntegerPolyhedron::getNumIds; - using IntegerPolyhedron::getNumLocalIds; - using IntegerPolyhedron::getNumSymbolIds; - using IntegerPolyhedron::getSpace; - MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output) - : IntegerPolyhedron(domain), output(output) {} + : domainSet(domain), output(output) {} MultiAffineFunction(const Matrix &output, const PresburgerSpace &space) - : IntegerPolyhedron(space), output(output) {} - - ~MultiAffineFunction() override = default; - Kind getKind() const override { return Kind::MultiAffineFunction; } - bool classof(const IntegerRelation *rel) const { - return rel->getKind() == Kind::MultiAffineFunction; - } + : domainSet(space), output(output) {} - unsigned getNumInputs() const { return getNumDimAndSymbolIds(); } + unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolIds(); } unsigned getNumOutputs() const { return output.getNumRows(); } bool isConsistent() const { - return output.getNumColumns() == getNumIds() + 1; + return output.getNumColumns() == domainSet.getNumIds() + 1; } - const IntegerPolyhedron &getDomain() const { return *this; } + const IntegerPolyhedron &getDomain() const { return domainSet; } + const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); } /// Insert `num` identifiers of the specified kind at position `pos`. /// Positions are relative to the kind of identifier. The coefficient columns /// corresponding to the added identifiers are initialized to zero. Return the /// absolute column position (i.e., not relative to the kind of identifier) /// of the first added identifier. - unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; - - /// Swap the posA^th identifier with the posB^th identifier.
- void swapId(unsigned posA, unsigned posB) override; + unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1); /// Remove the specified range of ids. - void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override; - using IntegerRelation::removeIdRange; + void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit); - /// Eliminate the `posB^th` local identifier, replacing every instance of it - /// with the `posA^th` local identifier. This should be used when the two - /// local variables are known to always take the same values. - void eliminateRedundantLocalId(unsigned posA, unsigned posB) override; + /// Given a MAF `other`, merges local identifiers such that both functions + /// have the union of their local ids, without changing either function's + /// domain (as a set of points) or its output. + void mergeLocalIds(MultiAffineFunction &other); /// Return whether the outputs of `this` and `other` agree wherever both /// functions are defined, i.e., the outputs should be equal for all points in @@ -114,6 +95,10 @@ void dump() const; private: + /// The IntegerPolyhedron representing the domain over which the function is + /// defined. + IntegerPolyhedron domainSet; + /// The function's output is a tuple of integers, with the ith element of the /// tuple defined by the affine expression given by the ith row of this output /// matrix. diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -130,6 +130,21 @@ SmallVectorImpl<unsigned> &denoms, unsigned localOffset, llvm::function_ref<bool(unsigned i, unsigned j)> merge); +/// Given two relations, A and B, add additional local ids to the sets such +/// that both have the union of the local ids in each set, without changing +/// the set of points that lie in A and B.
+/// +/// While taking the union, if a local id in either set has a division +/// representation that duplicates the division representation of another +/// local id in either set, it is not added to the union and is merged instead. +/// +/// On every possible merge, `merge(i, j)` is called, where `i` and `j` are the +/// positions of the local identifiers being merged (the same in both sets). If +/// `merge(i, j)` returns true, the divisions are merged; otherwise they are +/// left unmerged. +void mergeLocalIds(IntegerRelation &relA, IntegerRelation &relB, + llvm::function_ref<bool(unsigned i, unsigned j)> merge); + /// Compute the gcd of the range. int64_t gcdRange(ArrayRef<int64_t> range); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1092,36 +1092,11 @@ /// obtained, and thus these local ids are not considered for detecting /// duplicates. unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) { - assert(space.isCompatible(other.getSpace()) && - "Spaces should be compatible."); - IntegerRelation &relA = *this; IntegerRelation &relB = other; unsigned oldALocals = relA.getNumLocalIds(); - // Merge local ids of relA and relB without using division information, - // i.e. append local ids of `relB` to `relA` and insert local ids of `relA` - // to `relB` at start of its local ids. - unsigned initLocals = relA.getNumLocalIds(); - insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds()); - relB.insertId(IdKind::Local, 0, initLocals); - - // Get division representations from each rel. - std::vector<SmallVector<int64_t, 8>> divsA, divsB; - SmallVector<unsigned, 4> denomsA, denomsB; - relA.getLocalReprs(divsA, denomsA); - relB.getLocalReprs(divsB, denomsB); - - // Copy division information for relB into `divsA` and `denomsA`, so that - // these have the combined division information of both rels.
Since newly - // added local variables in relA and relB have no constraints, they will not - // have any division representation. - std::copy(divsB.begin() + initLocals, divsB.end(), - divsA.begin() + initLocals); - std::copy(denomsB.begin() + initLocals, denomsB.end(), - denomsA.begin() + initLocals); - // Merge function that merges the local variables in both sets by treating // them as the same identifier. auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool { @@ -1140,9 +1115,7 @@ return true; }; - // Merge all divisions by removing duplicate divisions. - unsigned localOffset = getIdKindOffset(IdKind::Local); - presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge); + presburger::mergeLocalIds(*this, other, merge); // Since we do not remove duplicate divisions in relA, this is guranteed to be // non-negative. diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -35,7 +35,7 @@ Optional<SmallVector<int64_t, 8>> MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const { - assert(point.size() == getNumDimAndSymbolIds() && + assert(point.size() == domainSet.getNumDimAndSymbolIds() && "Point has incorrect dimensionality!"); Optional<SmallVector<int64_t, 8>> maybeLocalValues = @@ -74,7 +74,7 @@ void MultiAffineFunction::print(raw_ostream &os) const { os << "Domain:"; - IntegerPolyhedron::print(os); + domainSet.print(os); os << "Output:\n"; output.print(os); os << "\n"; @@ -83,36 +83,24 @@ void MultiAffineFunction::dump() const { print(llvm::errs()); } bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { - return space.isCompatible(other.getSpace()) && + return getDomainSpace().isCompatible(other.getDomainSpace()) && getDomain().isEqual(other.getDomain()) && isEqualWhereDomainsOverlap(other); } unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos, unsigned num) { - assert((kind != IdKind::Domain ||
num == 0) && - "Domain has to be zero in a set"); - unsigned absolutePos = getIdKindOffset(kind) + pos; + assert(kind != IdKind::Domain && "Domain has to be zero in a set"); + unsigned absolutePos = domainSet.getIdKindOffset(kind) + pos; output.insertColumns(absolutePos, num); - return IntegerPolyhedron::insertId(kind, pos, num); -} - -void MultiAffineFunction::swapId(unsigned posA, unsigned posB) { - output.swapColumns(posA, posB); - IntegerPolyhedron::swapId(posA, posB); + return domainSet.insertId(kind, pos, num); } void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) { - output.removeColumns(idStart + getIdKindOffset(kind), idLimit - idStart); - IntegerPolyhedron::removeIdRange(kind, idStart, idLimit); -} - -void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA, - unsigned posB) { - unsigned localOffset = getIdKindOffset(IdKind::Local); - output.addToColumn(localOffset + posB, localOffset + posA, /*scale=*/1); - IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); + output.removeColumns(idStart + domainSet.getIdKindOffset(kind), + idLimit - idStart); + domainSet.removeIdRange(kind, idStart, idLimit); } void MultiAffineFunction::truncateOutput(unsigned count) { @@ -127,9 +115,37 @@ numOutputs = count; } +void MultiAffineFunction::mergeLocalIds(MultiAffineFunction &other) { + // Merge output local ids of both functions without using division + // information, i.e., append local ids of `other` to `this` and insert + // local ids of `this` to `other` at the start of its local ids. + output.insertColumns(domainSet.getIdKindEnd(IdKind::Local), + other.domainSet.getNumLocalIds()); + other.output.insertColumns(other.domainSet.getIdKindOffset(IdKind::Local), + domainSet.getNumLocalIds()); + + auto merge = [this, &other](unsigned i, unsigned j) -> bool { + // Merge local at position j into local at position i in function domain.
+ domainSet.eliminateRedundantLocalId(i, j); + other.domainSet.eliminateRedundantLocalId(i, j); + + unsigned localOffset = domainSet.getIdKindOffset(IdKind::Local); + + // Merge local j into local i in the output matrices of both functions. + output.addToColumn(localOffset + j, localOffset + i, 1); + output.removeColumn(localOffset + j); + other.output.addToColumn(localOffset + j, localOffset + i, 1); + other.output.removeColumn(localOffset + j); + + return true; + }; + + presburger::mergeLocalIds(domainSet, other.domainSet, merge); +} + bool MultiAffineFunction::isEqualWhereDomainsOverlap( MultiAffineFunction other) const { - if (!space.isCompatible(other.getSpace())) + if (!getDomainSpace().isCompatible(other.getDomainSpace())) return false; // `commonFunc` has the same output as `this`. @@ -139,7 +155,7 @@ commonFunc.mergeLocalIds(other); // After this, the domain of `commonFunc` will be the intersection of the // domains of `this` and `other`. - commonFunc.IntegerPolyhedron::append(other); + commonFunc.domainSet.append(other.domainSet); // `commonDomainMatching` contains the subset of the common domain // where the outputs of `this` and `other` match.
@@ -180,7 +196,7 @@ } void PWMAFunction::addPiece(const MultiAffineFunction &piece) { - assert(space.isCompatible(piece.getSpace()) && + assert(space.isCompatible(piece.getDomainSpace()) && "Piece to be added is not compatible with this PWMAFunction!"); assert(piece.isConsistent() && "Piece is internally inconsistent!"); assert(this->getDomain() diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -304,6 +304,39 @@ } } +void presburger::mergeLocalIds( + IntegerRelation &relA, IntegerRelation &relB, + llvm::function_ref<bool(unsigned i, unsigned j)> merge) { + assert(relA.getSpace().isCompatible(relB.getSpace()) && + "Spaces should be compatible."); + + // Merge local ids of relA and relB without using division information, + // i.e. append local ids of `relB` to `relA` and insert local ids of `relA` + // to `relB` at start of its local ids. + unsigned initLocals = relA.getNumLocalIds(); + relA.insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds()); + relB.insertId(IdKind::Local, 0, initLocals); + + // Get division representations from each rel. + std::vector<SmallVector<int64_t, 8>> divsA, divsB; + SmallVector<unsigned, 4> denomsA, denomsB; + relA.getLocalReprs(divsA, denomsA); + relB.getLocalReprs(divsB, denomsB); + + // Copy division information for relB into `divsA` and `denomsA`, so that + // these have the combined division information of both rels. Since newly + // added local variables in relA and relB have no constraints, they will not + // have any division representation. + std::copy(divsB.begin() + initLocals, divsB.end(), + divsA.begin() + initLocals); + std::copy(denomsB.begin() + initLocals, denomsB.end(), + denomsA.begin() + initLocals); + + // Merge all divisions by removing duplicate divisions.
+ unsigned localOffset = relA.getIdKindOffset(IdKind::Local); + presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge); +} + int64_t presburger::gcdRange(ArrayRef<int64_t> range) { int64_t gcd = 0; for (int64_t elem : range) {