diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -73,8 +73,6 @@ } const IntegerPolyhedron &getDomain() const { return *this; } - bool hasCompatibleDimensions(const MultiAffineFunction &f) const; - /// Insert `num` identifiers of the specified kind at position `pos`. /// Positions are relative to the kind of identifier. The coefficient columns /// corresponding to the added identifiers are initialized to zero. Return the @@ -98,6 +96,18 @@ /// the intersection of the domains. bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const; + /// Returns whether the underlying PresburgerSpace is equal to `other`. + /// Note that the number of outputs is not compared. + bool isSpaceEqual(const PresburgerSpace &other) const { + return PresburgerSpace::isEqual(other); + } + + /// Returns whether the underlying PresburgerLocalSpace is equal to `other`. + /// Note that the number of outputs is not compared. + bool isSpaceEqual(const PresburgerLocalSpace &other) const { + return PresburgerLocalSpace::isEqual(other); + } + /// Return whether the `this` and `other` are equal. This is the case if /// they lie in the same space, i.e. have the same dimensions, and their /// domains are identical and their outputs are equal on their domain. @@ -139,7 +149,7 @@ /// Support is provided to compare equality of two such functions as well as /// finding the value of the function at a point. Note that local ids in the /// piece are not supported for the latter. -class PWMAFunction : PresburgerSpace { +class PWMAFunction : public PresburgerSpace { public: PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs) : PresburgerSpace(numDims, numSymbols), numOutputs(numOutputs) { @@ -159,12 +169,6 @@ /// union of the domains of all the pieces. 
  PresburgerSet getDomain() const; - /// Check whether the `this` and the given function have compatible - /// dimensions, i.e., the same number of dimension inputs, symbol inputs, and - /// outputs. - bool hasCompatibleDimensions(const MultiAffineFunction &f) const; - bool hasCompatibleDimensions(const PWMAFunction &f) const; - /// Return the value at the specified point and an empty optional if the /// point does not lie in the domain. /// diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -113,6 +113,10 @@ /// Removes identifiers in the column range [idStart, idLimit). virtual void removeIdRange(unsigned idStart, unsigned idLimit); + /// Returns true if both the spaces are equal, i.e., if both spaces have the + /// same number of identifiers of each kind (excluding local identifiers). + bool isEqual(const PresburgerSpace &other) const; + /// Changes the partition between dimensions and symbols. Depending on the new /// symbol count, either a chunk of dimensional identifiers immediately before /// the split become symbols, or some of the symbols immediately after the @@ -193,6 +197,10 @@ PresburgerLocalSpace(unsigned numDims, unsigned numSymbols, unsigned numLocals) : PresburgerSpace(Set, /*numDomain=*/0, numDims, numSymbols, numLocals) {} + + /// Returns true if both the spaces are equal, i.e., if both spaces have the + /// same number of identifiers of each kind, including local identifiers. 
+ bool isEqual(const PresburgerLocalSpace &other) const; }; } // namespace presburger diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp --- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp @@ -49,9 +49,7 @@ } void IntegerPolyhedron::append(const IntegerPolyhedron &other) { - assert(other.getNumCols() == getNumCols()); - assert(other.getNumDimIds() == getNumDimIds()); - assert(other.getNumSymbolIds() == getNumSymbolIds()); + assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); inequalities.reserveRows(inequalities.getNumRows() + other.getNumInequalities()); @@ -1037,10 +1035,7 @@ /// division representation for some local id cannot be obtained, and thus these /// local ids are not considered for detecting duplicates. void IntegerPolyhedron::mergeLocalIds(IntegerPolyhedron &other) { - assert(getNumDimIds() == other.getNumDimIds() && - "Number of dimension ids should match"); - assert(getNumSymbolIds() == other.getNumSymbolIds() && - "Number of symbol ids should match"); + assert(PresburgerSpace::isEqual(other) && "Spaces should match."); IntegerPolyhedron &polyA = *this; IntegerPolyhedron &polyB = other; @@ -1856,8 +1851,7 @@ // lower bounds and the max of the upper bounds along each of the dimensions. 
LogicalResult IntegerPolyhedron::unionBoundingBox(const IntegerPolyhedron &otherCst) { - assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch"); - assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); + assert(PresburgerLocalSpace::isEqual(otherCst) && "Spaces should match."); assert(getNumLocalIds() == 0 && "local ids not supported yet here"); // Get the constraints common to both systems; these will be added as is to diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -77,7 +77,8 @@ void MultiAffineFunction::dump() const { print(llvm::errs()); } bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { - return hasCompatibleDimensions(other) && + return PresburgerSpace::isEqual(other) && + getNumOutputs() == other.getNumOutputs() && getDomain().isEqual(other.getDomain()) && isEqualWhereDomainsOverlap(other); } @@ -107,7 +108,8 @@ bool MultiAffineFunction::isEqualWhereDomainsOverlap( MultiAffineFunction other) const { - if (!hasCompatibleDimensions(other)) + if (!PresburgerSpace::isEqual(other) || + getNumOutputs() != other.getNumOutputs()) return false; // `commonFunc` has the same output as `this`. @@ -140,7 +142,8 @@ /// Two PWMAFunctions are equal if they have the same dimensionalities, /// the same domain, and take the same value at every point in the domain. 
bool PWMAFunction::isEqual(const PWMAFunction &other) const { - if (!hasCompatibleDimensions(other)) + if (!PresburgerSpace::isEqual(other) || + getNumOutputs() != other.getNumOutputs()) return false; if (!this->getDomain().isEqual(other.getDomain())) @@ -158,7 +161,8 @@ } void PWMAFunction::addPiece(const MultiAffineFunction &piece) { - assert(hasCompatibleDimensions(piece) && + assert(piece.isSpaceEqual(*this) && + piece.getNumOutputs() == getNumOutputs() && "Piece to be added is not compatible with this PWMAFunction!"); assert(piece.isConsistent() && "Piece is internally inconsistent!"); assert(this->getDomain() @@ -182,22 +182,3 @@ for (const MultiAffineFunction &piece : pieces) piece.print(os); } - -/// The hasCompatibleDimensions functions don't check the number of local ids; -/// functions are still compatible if they have differing number of locals. -bool MultiAffineFunction::hasCompatibleDimensions( - const MultiAffineFunction &f) const { - return getNumDimIds() == f.getNumDimIds() && - getNumSymbolIds() == f.getNumSymbolIds() && - getNumOutputs() == f.getNumOutputs(); -} -bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const { - return getNumDimIds() == f.getNumDimIds() && - getNumSymbolIds() == f.getNumSymbolIds() && - getNumOutputs() == f.getNumOutputs(); -} -bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const { - return getNumDimIds() == f.getNumDimIds() && - getNumSymbolIds() == f.getNumSymbolIds() && - getNumOutputs() == f.getNumOutputs(); -} diff --git a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp @@ -35,31 +35,10 @@ return integerPolyhedrons[index]; } -/// Assert that the IntegerPolyhedron and PresburgerSet live in -/// compatible spaces. 
-static void assertDimensionsCompatible(const IntegerPolyhedron &poly, - const PresburgerSet &set) { - assert(poly.getNumDimIds() == set.getNumDimIds() && - "Number of dimensions of the IntegerPolyhedron and PresburgerSet" - "do not match!"); - assert(poly.getNumSymbolIds() == set.getNumSymbolIds() && - "Number of symbols of the IntegerPolyhedron and PresburgerSet" - "do not match!"); -} - -/// Assert that the two PresburgerSets live in compatible spaces. -static void assertDimensionsCompatible(const PresburgerSet &setA, - const PresburgerSet &setB) { - assert(setA.getNumDimIds() == setB.getNumDimIds() && - "Number of dimensions of the PresburgerSets do not match!"); - assert(setA.getNumSymbolIds() == setB.getNumSymbolIds() && - "Number of symbols of the PresburgerSets do not match!"); -} - /// Mutate this set, turning it into the union of this set and the given /// IntegerPolyhedron. void PresburgerSet::unionPolyInPlace(const IntegerPolyhedron &poly) { - assertDimensionsCompatible(poly, *this); + assert(PresburgerSpace::isEqual(poly) && "Spaces should match"); integerPolyhedrons.push_back(poly); } @@ -68,14 +47,14 @@ /// This is accomplished by simply adding all the Poly of the given set to this /// set. void PresburgerSet::unionSetInPlace(const PresburgerSet &set) { - assertDimensionsCompatible(set, *this); + assert(PresburgerSpace::isEqual(set) && "Spaces should match"); for (const IntegerPolyhedron &poly : set.integerPolyhedrons) unionPolyInPlace(poly); } /// Return the union of this set and the given set. PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const { - assertDimensionsCompatible(set, *this); + assert(PresburgerSpace::isEqual(set) && "Spaces should match"); PresburgerSet result = *this; result.unionSetInPlace(set); return result; @@ -108,7 +87,7 @@ // If S_i or T_j have local variables, then S_i and T_j contains the local // variables of both. 
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const { - assertDimensionsCompatible(set, *this); + assert(PresburgerSpace::isEqual(set) && "Spaces should match"); PresburgerSet result(getNumDimIds(), getNumSymbolIds()); for (const IntegerPolyhedron &csA : integerPolyhedrons) { @@ -326,7 +305,7 @@ /// from that function. PresburgerSet PresburgerSet::getSetDifference(IntegerPolyhedron poly, const PresburgerSet &set) { - assertDimensionsCompatible(poly, set); + assert(poly.PresburgerSpace::isEqual(set) && "Spaces should match"); if (poly.isEmptyByGCDTest()) return PresburgerSet::getEmptySet(poly.getNumDimIds(), poly.getNumSymbolIds()); @@ -346,7 +325,7 @@ /// Return the result of subtract the given set from this set, i.e., /// return `this \ set`. PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const { - assertDimensionsCompatible(set, *this); + assert(PresburgerSpace::isEqual(set) && "Spaces should match"); PresburgerSet result(getNumDimIds(), getNumSymbolIds()); // We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i). for (const IntegerPolyhedron &poly : integerPolyhedrons) @@ -363,7 +342,7 @@ /// Two sets are equal iff they are subsets of each other. 
bool PresburgerSet::isEqual(const PresburgerSet &set) const { - assertDimensionsCompatible(set, *this); + assert(PresburgerSpace::isEqual(set) && "Spaces should match"); return this->isSubsetOf(set) && set.isSubsetOf(*this); } diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -154,6 +154,17 @@ numLocals -= numLocalsEliminated; } +bool PresburgerSpace::isEqual(const PresburgerSpace &other) const { + return getNumDomainIds() == other.getNumDomainIds() && + getNumRangeIds() == other.getNumRangeIds() && + getNumSymbolIds() == other.getNumSymbolIds(); +} + +bool PresburgerLocalSpace::isEqual(const PresburgerLocalSpace &other) const { + return PresburgerSpace::isEqual(other) && + getNumLocalIds() == other.getNumLocalIds(); +} + void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) { assert(newSymbolCount <= getNumDimAndSymbolIds() && "invalid separation position");