diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -83,6 +83,27 @@
   /// Return the intersection of this set and the given set.
   PresburgerRelation intersect(const PresburgerRelation &set) const;
 
+  /// Invert the relation, i.e. swap its domain and range.
+  ///
+  /// Formally, if `this`: A -> B then `inverse` updates `this` in-place to
+  /// `this`: B -> A.
+  void inverse();
+
+  /// Compose `this` relation with the given relation `rel` in-place.
+  ///
+  /// Formally, if `this`: A -> B, and `rel`: B -> C, then this function updates
+  /// `this` to `result`: A -> C where a point (a, c) belongs to `result`
+  /// iff there exists b such that (a, b) is in `this` and (b, c) is in `rel`.
+  void compose(const PresburgerRelation &rel);
+
+  /// Apply `rel`: A -> C to the domain of `this`: A -> B, in-place.
+  ///
+  /// Formally, R1.applyDomain(R2) = R2.inverse().compose(R1), yielding C -> B.
+  void applyDomain(const PresburgerRelation &rel);
+
+  /// Same as compose, provided for uniformity with applyDomain.
+  void applyRange(const PresburgerRelation &rel);
+
   /// Return true if the set contains the given point, and false otherwise.
bool containsPoint(ArrayRef point) const; bool containsPoint(ArrayRef point) const {
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -108,6 +109,58 @@
   return result;
 }
 
+void PresburgerRelation::inverse() {
+  // Swap the number of domain and range vars; symbol vars are unchanged.
+  PresburgerRelation result =
+      PresburgerRelation::getEmpty(PresburgerSpace::getRelationSpace(
+          getNumRangeVars(), getNumDomainVars(), getNumSymbolVars()));
+  for (IntegerRelation &cs : disjuncts) {
+    cs.inverse();
+    result.unionInPlace(cs);
+  }
+  // `result` is not used afterwards, so move it instead of copying every
+  // disjunct a second time.
+  *this = std::move(result);
+}
+
+void PresburgerRelation::compose(const PresburgerRelation &rel) {
+  assert(getSpace().getRangeSpace().isCompatible(
+             rel.getSpace().getDomainSpace()) &&
+         "Range of `this` should be compatible with domain of `rel`");
+
+  // Composition distributes over union: (A1 u A2) o (B1 u B2) is the union of
+  // the pairwise compositions of the disjuncts.
+  PresburgerRelation result =
+      PresburgerRelation::getEmpty(PresburgerSpace::getRelationSpace(
+          getNumDomainVars(), rel.getNumRangeVars(), getNumSymbolVars()));
+  for (const IntegerRelation &csA : disjuncts) {
+    for (const IntegerRelation &csB : rel.disjuncts) {
+      IntegerRelation composition = csA;
+      composition.compose(csB);
+      // Empty pieces contribute nothing to the union; don't store them.
+      if (!composition.isEmpty())
+        result.unionInPlace(composition);
+    }
+  }
+  *this = std::move(result);
+}
+
+void PresburgerRelation::applyDomain(const PresburgerRelation &rel) {
+  assert(getSpace().getDomainSpace().isCompatible(
+             rel.getSpace().getDomainSpace()) &&
+         "Domain of `this` should be compatible with domain of `rel`");
+
+  // With `this`: A -> B and `rel`: A -> C:
+  // inverse() gives B -> A, compose(rel) gives B -> C, inverse() gives C -> B.
+  inverse();
+  compose(rel);
+  inverse();
+}
+
+void PresburgerRelation::applyRange(const PresburgerRelation &rel) {
+  compose(rel);
+}
+
 /// Return the coefficients of the ineq in `rel` specified by `idx`.
 /// `idx` can refer not only to an actual inequality of `rel`, but also
 /// to either of the inequalities that make up an equality in `rel`.
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -7,6 +7,7 @@
   Parser.h
   ParserTest.cpp
+  PresburgerRelationTest.cpp
   PresburgerSetTest.cpp
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -0,0 +1,102 @@
+//===- PresburgerRelationTest.cpp - Tests for PresburgerRelation class ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "Parser.h"
+
+#include 
+#include 
+#include 
+
+using namespace mlir;
+using namespace presburger;
+
+static PresburgerRelation
+parsePresburgerRelationFromPresburgerSet(ArrayRef strs,
+                                         unsigned numDomain) {
+  assert(!strs.empty() && "strs should not be empty");
+  // The first `numDomain` set dims of each parsed set become domain vars.
+  IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
+  rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
+  PresburgerRelation result(rel);
+  for (unsigned i = 1, e = strs.size(); i < e; ++i) {
+    rel = parseIntegerPolyhedron(strs[i]);
+    rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
+    result.unionInPlace(rel);
+  }
+  return result;
+}
+
+TEST(PresburgerRelationTest, applyDomainAndRange) {
+  {
+    PresburgerRelation map1 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) <-> (x + N, y - N)
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)",
+         // (x, y) <-> (y, x)
+         "(x, y, a, b)[N] : (a - y == 0, b - x == 0)",
+         // (x, y) <-> (x + y, x - y)
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0)"},
+        2);
+    PresburgerRelation map2 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) <-> (x + y)
+         "(x, y, r)[N] : (r - x - y == 0)",
+         // (x, y) <-> (N)
+         "(x, y, r)[N] : (r - N == 0)",
+         // (x, y) <-> (y - x)
+         "(x, y, r)[N] : (r + x - y == 0)"},
+        2);
+
+    map1.applyRange(map2); // Now maps (x, y) -> r.
+
+    PresburgerRelation map3 = parsePresburgerRelationFromPresburgerSet(
+        {
+            // (x, y) <-> (x + y)
+            "(x, y, r)[N] : (r - x - y == 0)",
+            // (x, y) <-> (N)
+            "(x, y, r)[N] : (r - N == 0)",
+            // (x, y) <-> (y - x - 2N)
+            "(x, y, r)[N] : (r - y + x + 2 * N == 0)",
+            // (x, y) <-> (x - y)
+            "(x, y, r)[N] : (r - x + y == 0)",
+            // (x, y) <-> (2x)
+            "(x, y, r)[N] : (r - 2 * x == 0)",
+            // (x, y) <-> (-2y)
+            "(x, y, r)[N] : (r + 2 * y == 0)",
+        },
+        2);
+
+    EXPECT_TRUE(map1.isEqual(map3));
+  }
+
+  {
+    PresburgerRelation map1 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) <-> (y, x)
+         "(x, y, a, b)[N] : (y - a == 0, x - b == 0)",
+         // (x, y) <-> (x + N, y - N)
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)"},
+        2);
+    PresburgerRelation map2 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) <-> (x - y)
+         "(x, y, r)[N] : (x - y - r == 0)",
+         // (x, y) <-> (N)
+         "(x, y, r)[N] : (N - r == 0)"},
+        2);
+
+    map1.applyDomain(map2); // Now maps r -> (x, y).
+
+    PresburgerRelation map3 = parsePresburgerRelationFromPresburgerSet(
+        {// (y - x) <-> (x, y)
+         "(r, x, y)[N] : (y - x - r == 0)",
+         // (x - y - 2N) <-> (x, y)
+         "(r, x, y)[N] : (x - y - 2 * N - r == 0)",
+         // N <-> (x, y)
+         "(r, x, y)[N] : (N - r == 0)"},
+        1);
+
+    EXPECT_TRUE(map1.isEqual(map3));
+  }
+}