diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -505,6 +505,12 @@ /// Return a set corresponding to all points in the range of the relation. IntegerPolyhedron getRangeSet() const; + /// Intersect the given `poly` with the domain. + void intersectDomain(const IntegerPolyhedron &poly); + + /// Intersect the given `poly` with the range. + void intersectRange(const IntegerPolyhedron &poly); + /// Invert the relation i.e., swap it's domain and range. /// /// Formally, let the relation `this` be R: A -> B, then this operation diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -2102,6 +2102,36 @@ return IntegerPolyhedron(std::move(copyRel)); } +void IntegerRelation::intersectDomain(const IntegerPolyhedron &poly) { + assert(getDomainSet().getSpace().isCompatible(poly.getSpace()) && + "Domain set is not compatible with poly"); + + // Treating the poly as a relation, convert it from `0 -> R` to `R -> 0`. + IntegerRelation rel = poly; + rel.convertIdKind(IdKind::Range, 0, getNumIdKind(IdKind::Domain), + IdKind::Domain); + + // Append dummy range variables to make spaces compatible. + rel.appendId(IdKind::Range, getNumIdKind(IdKind::Range)); + + // Intersect in place. + mergeLocalIds(rel); + append(rel); +} + +void IntegerRelation::intersectRange(const IntegerPolyhedron &poly) { + assert(getRangeSet().getSpace().isCompatible(poly.getSpace()) && + "Range set is not compatible with poly"); + + IntegerRelation rel = poly; + + // Append dummy domain variables to make spaces compatible. + rel.appendId(IdKind::Domain, getNumIdKind(IdKind::Domain)); + + mergeLocalIds(rel); + append(rel); +} + void IntegerRelation::inverse() { unsigned numRangeIds = getNumIdKind(IdKind::Range); convertIdKind(IdKind::Domain, 0, getIdKindEnd(IdKind::Domain), IdKind::Range); diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -57,3 +57,37 @@ EXPECT_TRUE(rel.isEqual(inverseRel)); } + +TEST(IntegerRelationTest, intersectDomainAndRange) { + IntegerRelation rel = parseRelationFromSet( + "(x, y, z)[N, M]: (y floordiv 2 - N >= 0, z floordiv 5 - M" + ">= 0, x + y + z floordiv 7 == 0)", + 1); + + { + IntegerPolyhedron poly = parsePoly("(x)[N, M] : (x >= 0, M - x - 1 >= 0)"); + + IntegerRelation expectedRel = parseRelationFromSet( + "(x, y, z)[N, M]: (y floordiv 2 - N >= 0, z floordiv 5 - M" + ">= 0, x + y + z floordiv 7 == 0, x >= 0, M - x - 1 >= 0)", + 1); + + IntegerRelation copyRel = rel; + copyRel.intersectDomain(poly); + EXPECT_TRUE(copyRel.isEqual(expectedRel)); + } + + { + IntegerPolyhedron poly = + parsePoly("(y, z)[N, M] : (y >= 0, M - y - 1 >= 0, y + z == 0)"); + + IntegerRelation expectedRel = parseRelationFromSet( + "(x, y, z)[N, M]: (y floordiv 2 - N >= 0, z floordiv 5 - M" + ">= 0, x + y + z floordiv 7 == 0, y >= 0, M - y - 1 >= 0, y + z == 0)", + 1); + + IntegerRelation copyRel = rel; + copyRel.intersectRange(poly); + EXPECT_TRUE(copyRel.isEqual(expectedRel)); + } +}