diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -24,7 +24,7 @@ namespace mlir { namespace presburger { -/// An IntegerRelation is a PresburgerLocalSpace subject to affine constraints. +/// An IntegerRelation is a PresburgerSpace subject to affine constraints. /// Affine constraints can be inequalities or equalities in the form: /// /// Inequality: c_0*x_0 + c_1*x_1 + .... + c_{n-1}*x_{n-1} + c_n >= 0 @@ -42,7 +42,7 @@ /// /// Since IntegerRelation makes a distinction between dimensions, IdKind::Range /// and IdKind::Domain should be used to refer to dimension identifiers. -class IntegerRelation : public PresburgerLocalSpace { +class IntegerRelation : public PresburgerSpace { public: /// All derived classes of IntegerRelation. enum class Kind { @@ -59,7 +59,7 @@ unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDomain, unsigned numRange, unsigned numSymbols, unsigned numLocals) - : PresburgerLocalSpace(numDomain, numRange, numSymbols, numLocals), + : PresburgerSpace(numDomain, numRange, numSymbols, numLocals), equalities(0, getNumIds() + 1, numReservedEqualities, numReservedCols), inequalities(0, getNumIds() + 1, numReservedInequalities, numReservedCols) { @@ -158,15 +158,15 @@ /// this addition can be rolled back using truncate. struct CountsSnapshot { public: - CountsSnapshot(const PresburgerLocalSpace &space, unsigned numIneqs, + CountsSnapshot(const PresburgerSpace &space, unsigned numIneqs, unsigned numEqs) : space(space), numIneqs(numIneqs), numEqs(numEqs) {} - const PresburgerLocalSpace &getSpace() const { return space; }; + const PresburgerSpace &getSpace() const { return space; }; unsigned getNumIneqs() const { return numIneqs; } unsigned getNumEqs() const { return numEqs; } private: - PresburgerLocalSpace space; + PresburgerSpace space; unsigned numIneqs, numEqs; }; CountsSnapshot getCounts() const; @@ -540,7 +540,7 @@ Matrix inequalities; }; -/// An IntegerPolyhedron is a PresburgerLocalSpace subject to affine +/// An IntegerPolyhedron is a PresburgerSpace subject to affine /// constraints. Affine constraints can be inequalities or equalities in the /// form: /// diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -52,6 +52,8 @@ using IntegerPolyhedron::getNumIds; using IntegerPolyhedron::getNumLocalIds; using IntegerPolyhedron::getNumSymbolIds; + using PresburgerSpace::isSpaceCompatible; + using PresburgerSpace::isSpaceEqual; MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output) : IntegerPolyhedron(domain), output(output) {} @@ -96,16 +98,6 @@ /// the intersection of the domains. bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const; - /// Returns whether the underlying PresburgerSpace is equal to `other`. - bool isSpaceEqual(const PresburgerSpace &other) const { - return PresburgerSpace::isEqual(other); - }; - - /// Returns whether the underlying PresburgerLocalSpace is equal to `other`. - bool isSpaceEqual(const PresburgerLocalSpace &other) const { - return PresburgerLocalSpace::isEqual(other); - }; - /// Return whether the `this` and `other` are equal. This is the case if /// they lie in the same space, i.e. have the same dimensions, and their /// domains are identical and their outputs are equal on their domain. @@ -146,7 +138,8 @@ class PWMAFunction : public PresburgerSpace { public: PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs) - : PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols), + : PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols, + /*numLocals=*/0), numOutputs(numOutputs) { assert(numOutputs >= 1 && "The function must output something!"); } diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -121,7 +121,7 @@ /// dimension and symbols. PresburgerRelation(unsigned numDomain = 0, unsigned numRange = 0, unsigned numSymbols = 0) - : PresburgerSpace(numDomain, numRange, numSymbols) {} + : PresburgerSpace(numDomain, numRange, numSymbols, /*numLocals=*/0) {} /// The list of disjuncts that this set is the union of. SmallVector integerRelations; diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -61,21 +61,23 @@ /// be implemented as a space with zero domain. IdKind::SetDim should be used /// to refer to dimensions in such spaces. /// -/// PresburgerSpace does not allow identifiers of kind Local. See -/// PresburgerLocalSpace for an extension that does allow local identifiers. +/// Compatibility of two spaces implies that number of identifiers of each kind +/// other than Locals are equal. Equality of two spaces implies that number of +/// identifiers of each kind are equal. class PresburgerSpace { - friend PresburgerLocalSpace; - public: - PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols) - : PresburgerSpace(numDomain, numRange, numSymbols, 0) {} + PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0, + unsigned numSymbols = 0, unsigned numLocals = 0) + : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols), + numLocals(numLocals) {} virtual ~PresburgerSpace() = default; unsigned getNumDomainIds() const { return numDomain; } unsigned getNumRangeIds() const { return numRange; } - unsigned getNumSymbolIds() const { return numSymbols; } unsigned getNumSetDimIds() const { return numRange; } + unsigned getNumSymbolIds() const { return numSymbols; } + unsigned getNumLocalIds() const { return numLocals; } unsigned getNumDimIds() const { return numDomain + numRange; } unsigned getNumDimAndSymbolIds() const { @@ -113,9 +115,14 @@ /// some ids at the end. `num` must be less than the current number. void truncateIdKind(IdKind kind, unsigned num); - /// Returns true if both the spaces are equal i.e. if both spaces have the - /// same number of identifiers of each kind (excluding Local Identifiers). - bool isEqual(const PresburgerSpace &other) const; + /// Returns true if both the spaces are compatible i.e. if both spaces have + /// the same number of identifiers of each kind (excluding locals). + bool isSpaceCompatible(const PresburgerSpace &other) const; + + /// Returns true if both the spaces are equal including local identifiers i.e. + /// if both spaces have the same number of identifiers of each kind (including + /// locals). + bool isSpaceEqual(const PresburgerSpace &other) const; /// Changes the partition between dimensions and symbols. Depending on the new /// symbol count, either a chunk of dimensional identifiers immediately before @@ -127,11 +134,6 @@ void dump() const; private: - PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols, - unsigned numLocals) - : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols), - numLocals(numLocals) {} - // Number of identifiers corresponding to domain identifiers. unsigned numDomain; @@ -147,32 +149,6 @@ unsigned numLocals; }; -/// Extension of PresburgerSpace supporting Local identifiers. -class PresburgerLocalSpace : public PresburgerSpace { -public: - PresburgerLocalSpace(unsigned numDomain, unsigned numRange, - unsigned numSymbols, unsigned numLocals) - : PresburgerSpace(numDomain, numRange, numSymbols, numLocals) {} - - unsigned getNumLocalIds() const { return numLocals; } - - /// Insert `num` identifiers of the specified kind at position `pos`. - /// Positions are relative to the kind of identifier. Return the absolute - /// column position (i.e., not relative to the kind of identifier) of the - /// first added identifier. - unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; - - /// Removes identifiers in the column range [idStart, idLimit). - void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override; - - /// Returns true if both the spaces are equal i.e. if both spaces have the - /// same number of identifiers of each kind. - bool isEqual(const PresburgerLocalSpace &other) const; - - void print(llvm::raw_ostream &os) const; - void dump() const; -}; - } // namespace presburger } // namespace mlir diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -38,7 +38,7 @@ } void IntegerRelation::append(const IntegerRelation &other) { - assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); + assert(isSpaceEqual(other) && "Spaces must be equal."); inequalities.reserveRows(inequalities.getNumRows() + other.getNumInequalities()); @@ -60,12 +60,12 @@ } bool IntegerRelation::isEqual(const IntegerRelation &other) const { - assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); + assert(isSpaceEqual(other) && "Spaces must be equal."); return PresburgerRelation(*this).isEqual(PresburgerRelation(other)); } bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const { - assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); + assert(isSpaceEqual(other) && "Spaces must be equal."); return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other)); } @@ -128,8 +128,7 @@ } IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const { - return {PresburgerLocalSpace(*this), getNumInequalities(), - getNumEqualities()}; + return {PresburgerSpace(*this), getNumInequalities(), getNumEqualities()}; } void IntegerRelation::truncateIdKind(IdKind kind, @@ -149,7 +148,7 @@ unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); - unsigned insertPos = PresburgerLocalSpace::insertId(kind, pos, num); + unsigned insertPos = PresburgerSpace::insertId(kind, pos, num); inequalities.insertColumns(insertPos, num); equalities.insertColumns(insertPos, num); return insertPos; @@ -193,7 +192,7 @@ inequalities.removeColumns(offset + idStart, idLimit - idStart); // Remove eliminated identifiers from the space. - PresburgerLocalSpace::removeIdRange(kind, idStart, idLimit); + PresburgerSpace::removeIdRange(kind, idStart, idLimit); } void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) { @@ -1068,7 +1067,7 @@ /// division representation for some local id cannot be obtained, and thus these /// local ids are not considered for detecting duplicates. void IntegerRelation::mergeLocalIds(IntegerRelation &other) { - assert(PresburgerSpace::isEqual(other) && "Spaces should match."); + assert(isSpaceCompatible(other) && "Spaces should be compatible."); IntegerRelation &relA = *this; IntegerRelation &relB = other; @@ -1892,7 +1891,7 @@ // lower bounds and the max of the upper bounds along each of the dimensions. LogicalResult IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { - assert(PresburgerLocalSpace::isEqual(otherCst) && "Spaces should match."); + assert(isSpaceEqual(otherCst) && "Spaces should match."); assert(getNumLocalIds() == 0 && "local ids not supported yet here"); // Get the constraints common to both systems; these will be added as is to @@ -2056,7 +2055,7 @@ } void IntegerRelation::printSpace(raw_ostream &os) const { - PresburgerLocalSpace::print(os); + PresburgerSpace::print(os); os << getNumConstraints() << " constraints\n"; } diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -84,8 +84,7 @@ void MultiAffineFunction::dump() const { print(llvm::errs()); } bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { - return PresburgerSpace::isEqual(other) && - getDomain().isEqual(other.getDomain()) && + return isSpaceCompatible(other) && getDomain().isEqual(other.getDomain()) && isEqualWhereDomainsOverlap(other); } @@ -117,7 +116,7 @@ bool MultiAffineFunction::isEqualWhereDomainsOverlap( MultiAffineFunction other) const { - if (!PresburgerSpace::isEqual(other)) + if (!isSpaceCompatible(other)) return false; // `commonFunc` has the same output as `this`. @@ -150,7 +149,7 @@ /// Two PWMAFunctions are equal if they have the same dimensionalities, /// the same domain, and take the same value at every point in the domain. bool PWMAFunction::isEqual(const PWMAFunction &other) const { - if (!PresburgerSpace::isEqual(other)) + if (!isSpaceCompatible(other)) return false; if (!this->getDomain().isEqual(other.getDomain())) @@ -168,7 +167,7 @@ } void PWMAFunction::addPiece(const MultiAffineFunction &piece) { - assert(piece.isSpaceEqual(*this) && + assert(piece.isSpaceCompatible(*this) && "Piece to be added is not compatible with this PWMAFunction!"); assert(piece.isConsistent() && "Piece is internally inconsistent!"); assert(this->getDomain() diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -37,7 +37,7 @@ /// Mutate this set, turning it into the union of this set and the given /// IntegerRelation. void PresburgerRelation::unionInPlace(const IntegerRelation &disjunct) { - assert(PresburgerSpace::isEqual(disjunct) && "Spaces should match"); + assert(isSpaceCompatible(disjunct) && "Spaces should match"); integerRelations.push_back(disjunct); } @@ -46,7 +46,7 @@ /// This is accomplished by simply adding all the disjuncts of the given set /// to this set. void PresburgerRelation::unionInPlace(const PresburgerRelation &set) { - assert(PresburgerSpace::isEqual(set) && "Spaces should match"); + assert(isSpaceCompatible(set) && "Spaces should match"); for (const IntegerRelation &disjunct : set.integerRelations) unionInPlace(disjunct); } @@ -54,7 +54,7 @@ /// Return the union of this set and the given set. PresburgerRelation PresburgerRelation::unionSet(const PresburgerRelation &set) const { - assert(PresburgerSpace::isEqual(set) && "Spaces should match"); + assert(isSpaceCompatible(set) && "Spaces should match"); PresburgerRelation result = *this; result.unionInPlace(set); return result; @@ -91,7 +91,7 @@ // variables of both. PresburgerRelation PresburgerRelation::intersect(const PresburgerRelation &set) const { - assert(PresburgerSpace::isEqual(set) && "Spaces should match"); + assert(isSpaceCompatible(set) && "Spaces should match"); PresburgerRelation result(getNumDomainIds(), getNumRangeIds(), getNumSymbolIds()); @@ -281,7 +281,7 @@ /// returning from that function. static PresburgerRelation getSetDifference(IntegerRelation disjunct, const PresburgerRelation &set) { - assert(disjunct.PresburgerSpace::isEqual(set) && "Spaces should match"); + assert(disjunct.isSpaceCompatible(set) && "Spaces should match"); if (disjunct.isEmptyByGCDTest()) return PresburgerRelation::getEmpty(disjunct.getNumDomainIds(), disjunct.getNumRangeIds(), @@ -307,7 +307,7 @@ /// return `this \ set`. PresburgerRelation PresburgerRelation::subtract(const PresburgerRelation &set) const { - assert(PresburgerSpace::isEqual(set) && "Spaces should match"); + assert(isSpaceCompatible(set) && "Spaces should match"); PresburgerRelation result(getNumDomainIds(), getNumRangeIds(), getNumSymbolIds()); // We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i). @@ -325,7 +325,7 @@ /// Two sets are equal iff they are subsets of each other. bool PresburgerRelation::isEqual(const PresburgerRelation &set) const { - assert(PresburgerSpace::isEqual(set) && "Spaces should match"); + assert(isSpaceCompatible(set) && "Spaces should match"); return this->isSubsetOf(set) && set.isSubsetOf(*this); } diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -68,7 +68,7 @@ else if (kind == IdKind::Symbol) numSymbols += num; else - llvm_unreachable("PresburgerSpace does not support local identifiers!"); + numLocals += num; return absolutePos; } @@ -88,7 +88,7 @@ else if (kind == IdKind::Symbol) numSymbols -= numIdsEliminated; else - llvm_unreachable("PresburgerSpace does not support local identifiers!"); + numLocals -= numIdsEliminated; } void PresburgerSpace::truncateIdKind(IdKind kind, unsigned num) { @@ -97,37 +97,14 @@ removeIdRange(kind, num, curNum); } -unsigned PresburgerLocalSpace::insertId(IdKind kind, unsigned pos, - unsigned num) { - if (kind == IdKind::Local) { - numLocals += num; - return getIdKindOffset(IdKind::Local) + pos; - } - return PresburgerSpace::insertId(kind, pos, num); -} - -void PresburgerLocalSpace::removeIdRange(IdKind kind, unsigned idStart, - unsigned idLimit) { - assert(idLimit <= getNumIdKind(kind) && "invalid id limit"); - - if (idStart >= idLimit) - return; - - if (kind == IdKind::Local) - numLocals -= idLimit - idStart; - else - PresburgerSpace::removeIdRange(kind, idStart, idLimit); -} - -bool PresburgerSpace::isEqual(const PresburgerSpace &other) const { +bool PresburgerSpace::isSpaceCompatible(const PresburgerSpace &other) const { return getNumDomainIds() == other.getNumDomainIds() && getNumRangeIds() == other.getNumRangeIds() && getNumSymbolIds() == other.getNumSymbolIds(); } -bool PresburgerLocalSpace::isEqual(const PresburgerLocalSpace &other) const { - return PresburgerSpace::isEqual(other) && - getNumLocalIds() == other.getNumLocalIds(); +bool PresburgerSpace::isSpaceEqual(const PresburgerSpace &other) const { + return isSpaceCompatible(other) && getNumLocalIds() == other.getNumLocalIds(); } void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) { @@ -140,14 +117,8 @@ void PresburgerSpace::print(llvm::raw_ostream &os) const { os << "Domain: " << getNumDomainIds() << ", " << "Range: " << getNumRangeIds() << ", " - << "Symbols: " << getNumSymbolIds() << "\n"; + << "Symbols: " << getNumSymbolIds() << ", " + << "Locals: " << getNumLocalIds() << "\n"; } void PresburgerSpace::dump() const { print(llvm::errs()); } - -void PresburgerLocalSpace::print(llvm::raw_ostream &os) const { - PresburgerSpace::print(os); - os << "Locals: " << getNumLocalIds() << "\n"; -} - -void PresburgerLocalSpace::dump() const { print(llvm::errs()); }