diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -133,6 +133,11 @@
   /// (see IntegerRelation::findIntegerSample()).
   bool isEqual(const IntegerRelation &other) const;
 
+  /// Return whether `this` and `other` are syntactically equal: equal spaces
+  /// (including local variables) and identical equalities and inequalities,
+  /// compared row by row. A `false` result does not imply they differ.
+  bool isPlainEqual(const IntegerRelation &other) const;
+
   /// Return whether this is a subset of the given IntegerRelation. This is
   /// integer-exact and somewhat expensive, since it uses the integer emptiness
   /// check (see IntegerRelation::findIntegerSample()).
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -138,6 +138,10 @@
   /// otherwise.
   bool isPlainUniverse() const;
 
+  /// Return true if this set is obviously equal to `set`: both have the same
+  /// disjunct count and the disjuncts at each index are plainly equal.
+  bool isPlainEqual(const PresburgerRelation &set) const;
+
   /// Return true if the set is consist of a single disjunct, without any local
   /// variables, false otherwise.
   bool isConvexNoLocals() const;
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -80,6 +80,33 @@
   return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
 }
 
+/// Relations are plainly equal when their spaces are equal (including local
+/// variables, so both have the same number of columns) and every equality and
+/// inequality matches the one at the same index, coefficient by coefficient.
+bool IntegerRelation::isPlainEqual(const IntegerRelation &other) const {
+  if (!space.isEqual(other.getSpace()))
+    return false;
+  if (getNumEqualities() != other.getNumEqualities())
+    return false;
+  if (getNumInequalities() != other.getNumInequalities())
+    return false;
+
+  unsigned cols = getNumCols();
+  for (unsigned i = 0, eqs = getNumEqualities(); i < eqs; ++i) {
+    for (unsigned j = 0; j < cols; ++j) {
+      if (atEq(i, j) != other.atEq(i, j))
+        return false;
+    }
+  }
+  for (unsigned i = 0, ineqs = getNumInequalities(); i < ineqs; ++i) {
+    for (unsigned j = 0; j < cols; ++j) {
+      if (atIneq(i, j) != other.atIneq(i, j))
+        return false;
+    }
+  }
+  return true;
+}
+
 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
   assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
   return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other));
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -53,9 +53,28 @@
 /// Mutate this set, turning it into the union of this set and the given set.
 ///
 /// This is accomplished by simply adding all the disjuncts of the given set
-/// to this set.
+/// to this set, except when the result is obvious without any work: when the
+/// sets are plainly equal, or when either one is plainly empty or universal.
 void PresburgerRelation::unionInPlace(const PresburgerRelation &set) {
   assert(space.isCompatible(set.getSpace()) && "Spaces should match");
+  // A u A = A. This also turns self-union (`&set == this`) into a no-op.
+  if (isPlainEqual(set))
+    return;
+  // A plainly empty operand contributes nothing to the union.
+  if (isPlainEmpty()) {
+    disjuncts = set.disjuncts;
+    return;
+  }
+  if (set.isPlainEmpty())
+    return;
+  // A plainly universal operand absorbs the other one.
+  if (isPlainUniverse())
+    return;
+  if (set.isPlainUniverse()) {
+    disjuncts = set.disjuncts;
+    return;
+  }
+
   for (const IntegerRelation &disjunct : set.disjuncts)
     unionInPlace(disjunct);
 }
@@ -484,6 +503,13 @@
 PresburgerRelation::subtract(const PresburgerRelation &set) const {
   assert(space.isCompatible(set.getSpace()) && "Spaces should match");
   PresburgerRelation result(getSpace());
+
+  // If both sets are plainly equal, their difference is empty: skip the
+  // per-disjunct set differences below and return the still-empty `result`
+  // directly.
+  if (isPlainEqual(set))
+    return result;
+
   // We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i).
   for (const IntegerRelation &disjunct : disjuncts)
     result.unionInPlace(getSetDifference(disjunct, set));
@@ -503,6 +529,26 @@
   return this->isSubsetOf(set) && set.isSubsetOf(*this);
 }
 
+/// Return true if `this` and `set` are plainly equal: they have compatible
+/// spaces, the same number of disjuncts, and each disjunct is plainly equal to
+/// the disjunct at the same index in `set`. This is a cheap, order-sensitive
+/// syntactic check; a `false` result does not imply the sets differ.
+bool PresburgerRelation::isPlainEqual(const PresburgerRelation &set) const {
+  if (!space.isCompatible(set.getSpace()))
+    return false;
+
+  if (getNumDisjuncts() != set.getNumDisjuncts())
+    return false;
+
+  // Disjuncts are compared positionally; the same disjuncts listed in a
+  // different order are not detected as equal.
+  for (unsigned i = 0, n = getNumDisjuncts(); i < n; ++i) {
+    if (!getDisjunct(i).isPlainEqual(set.getDisjunct(i)))
+      return false;
+  }
+  return true;
+}
+
 /// Return true if the Presburger relation represents the universe set, false
 /// otherwise. It is a simple check that only check if the relation has at least
 /// one unconstrained disjunct, indicating the absence of constraints or