diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -505,6 +505,12 @@
   /// Return a set corresponding to all points in the range of the relation.
   IntegerPolyhedron getRangeSet() const;
 
+  /// Invert the relation, i.e., swap its domain and range.
+  ///
+  /// Formally, let the relation `this` be R: A -> B, then this operation
+  /// modifies R to be B -> A.
+  void inverse();
+
   void print(raw_ostream &os) const;
   void dump() const;
 
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2102,6 +2102,20 @@
   return IntegerPolyhedron(std::move(copyRel));
 }
 
+void IntegerRelation::inverse() {
+  // Read the range size before the first conversion: moving the domain ids
+  // into the range changes getNumIdKind(IdKind::Range).
+  unsigned numDomainIds = getNumIdKind(IdKind::Domain);
+  unsigned numRangeIds = getNumIdKind(IdKind::Range);
+
+  // convertIdKind takes id bounds relative to the source kind, so both calls
+  // pass per-kind counts. The domain ids end up after the existing range
+  // ids, leaving the original range ids as the first `numRangeIds` range ids,
+  // which the second call then moves into the now-empty domain.
+  convertIdKind(IdKind::Domain, 0, numDomainIds, IdKind::Range);
+  convertIdKind(IdKind::Range, 0, numRangeIds, IdKind::Domain);
+}
+
 void IntegerRelation::printSpace(raw_ostream &os) const {
   space.print(os);
   os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -41,3 +41,19 @@
 
   EXPECT_TRUE(rangeSet.isEqual(expectedRangeSet));
 }
+
+TEST(IntegerRelationTest, inverse) {
+  IntegerRelation rel =
+      parseRelationFromSet("(x, y, z)[N, M] : (z - x - y == 0, x >= 0, N - x "
+                           ">= 0, y >= 0, M - y >= 0)",
+                           2);
+
+  IntegerRelation inverseRel =
+      parseRelationFromSet("(z, x, y)[N, M] : (x >= 0, N - x >= 0, y >= 0, M "
+                           "- y >= 0, x + y - z == 0)",
+                           1);
+
+  rel.inverse();
+
+  EXPECT_TRUE(rel.isEqual(inverseRel));
+}