diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -95,6 +95,10 @@ /// intersection with no simplification of any sort attempted. void append(const IntegerRelation &other); + /// Return the intersection of `this` and `other`. The local ids of both + /// relations are merged first, then the constraints of `other` are appended. + IntegerRelation intersect(IntegerRelation other) const; + /// Return whether `this` and `other` are equal. This is integer-exact /// and somewhat expensive, since it uses the integer emptiness check /// (see IntegerRelation::findIntegerSample()). diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -52,6 +52,13 @@ } } +IntegerRelation IntegerRelation::intersect(IntegerRelation other) const { + IntegerRelation result = *this; + result.mergeLocalIds(other); + result.append(other); + return result; +} + bool IntegerRelation::isEqual(const IntegerRelation &other) const { assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal."); return PresburgerRelation(*this).isEqual(PresburgerRelation(other)); diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -97,11 +97,9 @@ getNumSymbolIds()); for (const IntegerRelation &csA : integerRelations) { for (const IntegerRelation &csB : set.integerRelations) { - IntegerRelation csACopy = csA, csBCopy = csB; - csACopy.mergeLocalIds(csBCopy); - csACopy.append(csBCopy); - if (!csACopy.isEmpty()) - result.unionInPlace(csACopy); + IntegerRelation intersection = csA.intersect(csB); + if (!intersection.isEmpty()) 
+ result.unionInPlace(intersection); } } return result;