diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -97,6 +97,14 @@ /// operation returns (A intersection C) -> B. PresburgerRelation intersectDomain(const PresburgerSet &set) const; + /// Return a set corresponding to all points in the domain of the relation. + /// This is the union of the domains of all disjuncts of the relation. + PresburgerSet getDomainSet() const; + + /// Return a set corresponding to all points in the range of the relation. + /// This is the union of the ranges of all disjuncts of the relation. + PresburgerSet getRangeSet() const; + /// Invert the relation, i.e. swap its domain and range. /// /// Formally, if `this`: A -> B then `inverse` updates `this` in-place to diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -164,6 +164,20 @@ return intersect(other); } +PresburgerSet PresburgerRelation::getDomainSet() const { + PresburgerSet result = PresburgerSet::getEmpty(space.getDomainSpace()); + for (const IntegerRelation &cs : disjuncts) + result.unionInPlace(cs.getDomainSet()); + return result; +} + +PresburgerSet PresburgerRelation::getRangeSet() const { + PresburgerSet result = PresburgerSet::getEmpty(space.getRangeSpace()); + for (const IntegerRelation &cs : disjuncts) + result.unionInPlace(cs.getRangeSet()); + return result; +} + void PresburgerRelation::inverse() { for (IntegerRelation &cs : disjuncts) cs.inverse(); diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp @@ -191,7 +191,7 @@ } } -TEST(IntegerRelationTest, symbolicLexOpt) {
+TEST(PresburgerRelationTest, symbolicLexOpt) { PresburgerRelation rel1 = parsePresburgerRelationFromPresburgerSet( {"(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, " "2 * N - x >= 0, 2 * N - y >= 0)", @@ -296,3 +296,29 @@ EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty()); EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexMax3)); } + +TEST(PresburgerRelationTest, getDomainAndRangeSet) { + PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet( + {// Disjunct 1: (x, y) -> (x + N, y - N) for x in [-N, 0], y in [N, 2N]. + "(x, y, a, b)[N] : (a >= 0, b >= 0, N - a >= 0, N - b >= 0, x - a + N " + "== 0, y - b - N == 0)", + // Disjunct 2: (x, y) -> (-y, -x) for x in [-2N, 0], y in [-2N, 0]. + "(x, y, a, b)[N] : (a >= 0, b >= 0, 2 * N - a >= 0, 2 * N - b >= 0, a + " + "y == 0, b + x == 0)"}, + 2); + + PresburgerSet domainSet = rel.getDomainSet(); + + PresburgerSet expectedDomainSet = parsePresburgerSet( + {"(x, y)[N] : (x + N >= 0, -x >= 0, y - N >= 0, 2 * N - y >= 0)", + "(x, y)[N] : (x + 2 * N >= 0, -x >= 0, y + 2 * N >= 0, -y >= 0)"}); + + EXPECT_TRUE(domainSet.isEqual(expectedDomainSet)); + + PresburgerSet rangeSet = rel.getRangeSet(); + + PresburgerSet expectedRangeSet = parsePresburgerSet( + {"(x, y)[N] : (x >= 0, 2 * N - x >= 0, y >= 0, 2 * N - y >= 0)"}); + + EXPECT_TRUE(rangeSet.isEqual(expectedRangeSet)); +}