diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -64,6 +64,11 @@
   /// exceeds that of some disjunct, an assert failure will occur.
   void setSpace(const PresburgerSpace &oSpace);
 
+  /// Insert `num` variables of the given `kind` at position `pos`, relative
+  /// to the variables of that kind, into the space and into every disjunct
+  /// of this relation. The relation is modified in place.
+  void insertVarInPlace(VarKind kind, unsigned pos, unsigned num = 1);
+
   /// Return a reference to the list of disjuncts.
   ArrayRef<IntegerRelation> getAllDisjuncts() const;
 
@@ -83,6 +88,18 @@
   /// Return the intersection of this set and the given set.
   PresburgerRelation intersect(const PresburgerRelation &set) const;
 
+  /// Return the intersection of the range of this relation with `set`.
+  ///
+  /// Formally, let the relation `this` be R: A -> B and `set` is C, then this
+  /// operation returns A -> (B intersection C). `this` is not modified.
+  PresburgerRelation intersectRange(const PresburgerSet &set) const;
+
+  /// Return the intersection of the domain of this relation with `set`.
+  ///
+  /// Formally, let the relation `this` be R: A -> B and `set` is C, then this
+  /// operation returns (A intersection C) -> B. `this` is not modified.
+  PresburgerRelation intersectDomain(const PresburgerSet &set) const;
+
   /// Invert the relation, i.e. swap its domain and range.
   ///
   /// Formally, if `this`: A -> B then `inverse` updates `this` in-place to
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -30,6 +30,13 @@
     disjunct.setSpaceExceptLocals(space);
 }
 
+void PresburgerRelation::insertVarInPlace(VarKind kind, unsigned pos,
+                                          unsigned num) {
+  for (IntegerRelation &cs : disjuncts)
+    cs.insertVar(kind, pos, num);
+  space.insertVar(kind, pos, num);
+}
+
 unsigned PresburgerRelation::getNumDisjuncts() const {
   return disjuncts.size();
 }
@@ -117,6 +124,32 @@
   return result;
 }
 
+PresburgerRelation
+PresburgerRelation::intersectRange(const PresburgerSet &set) const {
+  assert(space.getRangeSpace().isCompatible(set.getSpace()) &&
+         "Range of `this` must be compatible with range of `set`");
+
+  // Turn `set` into a relation whose range is `set` and whose domain has one
+  // (unconstrained) variable per domain variable of `this`, then intersect.
+  PresburgerRelation other = set;
+  other.insertVarInPlace(VarKind::Domain, 0, getNumDomainVars());
+  return intersect(other);
+}
+
+PresburgerRelation
+PresburgerRelation::intersectDomain(const PresburgerSet &set) const {
+  assert(space.getDomainSpace().isCompatible(set.getSpace()) &&
+         "Domain of `this` must be compatible with range of `set`");
+
+  // The variables of `set` are range variables. Add one domain variable per
+  // *range* variable of `this`, then swap domain and range so that `set`
+  // becomes the domain and the inserted variables match the range of `this`.
+  PresburgerRelation other = set;
+  other.insertVarInPlace(VarKind::Domain, 0, getNumRangeVars());
+  other.inverse();
+  return intersect(other);
+}
+
 void PresburgerRelation::inverse() {
   for (IntegerRelation &cs : disjuncts)
     cs.inverse();
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -31,6 +31,92 @@
   return result;
 }
 
+TEST(PresburgerRelationTest, intersectDomainAndRange) {
+  PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
+      {// (x, y) -> (x + N, y - N)
+       "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)",
+       // (x, y) -> (x + y, x - y)
+       "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0)",
+       // (x, y) -> (x - y, y - x)
+       "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0)"},
+      2);
+
+  {
+    PresburgerSet set =
+        parsePresburgerSet({// (2x, x)
+                            "(a, b)[N] : (a - 2 * b == 0)",
+                            // (x, -x)
+                            "(a, b)[N] : (a + b == 0)",
+                            // (N, N)
+                            "(a, b)[N] : (a - N == 0, b - N == 0)"});
+
+    PresburgerRelation expectedRel = parsePresburgerRelationFromPresburgerSet(
+        {"(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, x - 2 * y == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, x + y == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, x - N == 0, y - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, x - 2 * y == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, x + y == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, x - N == 0, y - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, x - 2 * y == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, x + y == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, x - N == 0, y - N "
+         "== 0)"},
+        2);
+
+    PresburgerRelation computedRel = rel.intersectDomain(set);
+    EXPECT_TRUE(computedRel.isEqual(expectedRel));
+  }
+
+  {
+    PresburgerSet set =
+        parsePresburgerSet({// (2x, x)
+                            "(a, b)[N] : (a - 2 * b == 0)",
+                            // (x, -x)
+                            "(a, b)[N] : (a + b == 0)",
+                            // (N, N)
+                            "(a, b)[N] : (a - N == 0, b - N == 0)"});
+
+    PresburgerRelation expectedRel = parsePresburgerRelationFromPresburgerSet(
+        {"(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, a - 2 * b == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, a + b == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, a - N == 0, b - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, a - 2 * b == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, a + b == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, a - N == 0, b - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, a - 2 * b == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, a + b == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, a - N == 0, b - N "
+         "== 0)"},
+        2);
+
+    PresburgerRelation computedRel = rel.intersectRange(set);
+    EXPECT_TRUE(computedRel.isEqual(expectedRel));
+  }
+
+  {
+    // Relations whose domain and range have different numbers of variables.
+    // (x) -> (x, x + 1)
+    PresburgerRelation rel2 = parsePresburgerRelationFromPresburgerSet(
+        {"(x, a, b) : (a - x == 0, b - x - 1 == 0)"}, 1);
+
+    PresburgerSet domainSet = parsePresburgerSet({"(x) : (x >= 0)"});
+    PresburgerRelation expectedDomainRel =
+        parsePresburgerRelationFromPresburgerSet(
+            {"(x, a, b) : (a - x == 0, b - x - 1 == 0, x >= 0)"}, 1);
+    EXPECT_TRUE(rel2.intersectDomain(domainSet).isEqual(expectedDomainRel));
+
+    PresburgerSet rangeSet = parsePresburgerSet({"(a, b) : (a >= 0)"});
+    PresburgerRelation expectedRangeRel =
+        parsePresburgerRelationFromPresburgerSet(
+            {"(x, a, b) : (a - x == 0, b - x - 1 == 0, a >= 0)"}, 1);
+    EXPECT_TRUE(rel2.intersectRange(rangeSet).isEqual(expectedRangeRel));
+  }
+}
+
 TEST(PresburgerRelationTest, applyDomainAndRange) {
   {
     PresburgerRelation map1 = parsePresburgerRelationFromPresburgerSet(