diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -459,10 +459,16 @@
   void removeDuplicateDivs();
 
   /// Converts identifiers of kind srcKind in the range [idStart, idLimit) to
-  /// variables of kind dstKind and placed after all the other variables of kind
-  /// dstKind. The internal ordering among the moved variables is preserved.
+  /// variables of kind dstKind, inserted at position `pos` of dstKind, which
+  /// must not exceed getNumIdKind(dstKind). The overload without `pos` places
+  /// them after all the other variables of kind dstKind. The internal ordering
+  /// among the moved variables is preserved.
   void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit,
-                     IdKind dstKind);
+                     IdKind dstKind, unsigned pos);
+  void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit,
+                     IdKind dstKind) {
+    convertIdKind(srcKind, idStart, idLimit, dstKind, getNumIdKind(dstKind));
+  }
   void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) {
     convertIdKind(kind, idStart, idLimit, IdKind::Local);
   }
@@ -514,6 +520,22 @@
   /// Get inverse of this relation.
   void inverse();
 
+  /// Given a relation `rel`, apply the relation to the domain of this relation.
+  /// The domain of `this` must be compatible with the domain of `rel`.
+  ///
+  /// Formally, let the relation `this` be R1: A -> B, and `rel` be
+  /// R2: A -> C. `this` is replaced in place by the relation R3: C -> B, such
+  /// that a point (c, b) \in R3 iff \exists a, (a, b) \in R1, (a, c) \in R2.
+  void applyDomain(const IntegerRelation &rel);
+
+  /// Given a relation `rel`, apply the relation to the range of this relation.
+  /// The range of `this` must be compatible with the domain of `rel`.
+  ///
+  /// Formally, let the relation `this` be R1: A -> B, and `rel` be
+  /// R2: B -> C. `this` is replaced in place by the relation R3: A -> C, such
+  /// that a point (a, c) \in R3 iff \exists b, (a, b) \in R1, (b, c) \in R2.
+  void applyRange(const IntegerRelation &rel);
+
   void print(raw_ostream &os) const;
 
   void dump() const;
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1184,16 +1184,18 @@
 }
 
 void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart,
-                                    unsigned idLimit, IdKind dstKind) {
+                                    unsigned idLimit, IdKind dstKind,
+                                    unsigned pos) {
   assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range");
+  assert(pos <= getNumIdKind(dstKind) && "Invalid insertion position");
 
   if (idStart >= idLimit)
     return;
 
-  // Append new local variables corresponding to the dimensions to be converted.
-  unsigned newIdsBegin = getIdKindEnd(dstKind);
+  // Insert new variables of kind dstKind at position `pos`, one for each
+  // variable to be converted.
   unsigned convertCount = idLimit - idStart;
-  appendId(dstKind, convertCount);
+  unsigned newIdsBegin = insertId(dstKind, pos, convertCount);
 
   // Swap the new local variables with dimensions.
   //
@@ -2142,6 +2144,43 @@
   convertIdKind(IdKind::Range, 0, numRangeIds, IdKind::Domain);
 }
 
+void IntegerRelation::applyRange(const IntegerRelation &rel) {
+  assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) &&
+         "Range of `this` should be compatible with Domain of `rel`");
+
+  IntegerRelation copyRel = rel;
+
+  // Let relation `this` be R1: A -> B, and `rel` be R2: B -> C.
+  // We convert R1 to A -> (B X C) and R2 to the set (B X C), then intersect
+  // the range of R1 with R2. Projecting out B then gives R1: A -> C.
+  unsigned numBIds = getNumRangeIds();
+
+  // Convert R1 from A -> B to A -> (B X C).
+  appendId(IdKind::Range, copyRel.getNumRangeIds());
+
+  // Convert R2 from B -> C to the set (B X C) by moving its domain ids to the
+  // front of its range. This relies on R2 having exactly numBIds domain ids,
+  // as required by the assertion above.
+  // TODO: Using nested spaces here would help, since we could directly
+  // intersect the range with another relation.
+  copyRel.convertIdKind(IdKind::Domain, 0, numBIds, IdKind::Range, 0);
+
+  // Intersect R2 with the range of R1.
+  intersectRange(IntegerPolyhedron(copyRel));
+
+  // Project out B in R1 by turning its ids into local ids.
+  convertIdKind(IdKind::Range, 0, numBIds, IdKind::Local);
+}
+
+void IntegerRelation::applyDomain(const IntegerRelation &rel) {
+  // Let relation `this` be R1: A -> B, and `rel` be R2: A -> C. Inverting R1
+  // gives B -> A, applying R2 to its range gives B -> C, and inverting that
+  // gives the result C -> B.
+  inverse();
+  applyRange(rel);
+  inverse();
+}
+
 void IntegerRelation::printSpace(raw_ostream &os) const {
   space.print(os);
   os << getNumConstraints() << " constraints\n";
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -91,3 +91,45 @@
     EXPECT_TRUE(copyRel.isEqual(expectedRel));
   }
 }
+
+TEST(IntegerRelationTest, applyDomainAndRange) {
+  {
+    IntegerRelation map1 = parseRelationFromSet(
+        "(x, y, a, b)[N] : (a - x - N == 0, b - y + N == 0)", 2);
+    IntegerRelation map2 =
+        parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2);
+
+    map1.applyRange(map2);
+
+    IntegerRelation map3 =
+        parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2);
+
+    EXPECT_TRUE(map1.isEqual(map3));
+  }
+
+  {
+    IntegerRelation map1 = parseRelationFromSet(
+        "(x, y, a, b)[N] : (a - x + N == 0, b - y - N == 0)", 2);
+    IntegerRelation map2 =
+        parseRelationFromSet("(x, y, a, b)[N] : (a - N == 0, b - N == 0)", 2);
+
+    IntegerRelation map3 =
+        parseRelationFromSet("(x, y, a, b)[N] : (x - N == 0, y - N == 0)", 2);
+
+    map1.applyDomain(map2);
+
+    EXPECT_TRUE(map1.isEqual(map3));
+  }
+}
+
+TEST(IntegerRelationTest, convertIdKindToPos) {
+  // (x, y) -> (a) with a = x + 2y. Moving `y` to position 0 of the range must
+  // give (x) -> (y, a), not (x) -> (a, y).
+  IntegerRelation rel =
+      parseRelationFromSet("(x, y, a) : (a - x - 2 * y == 0)", 2);
+  rel.convertIdKind(IdKind::Domain, 1, 2, IdKind::Range, 0);
+
+  IntegerRelation expected =
+      parseRelationFromSet("(x, y, a) : (a - x - 2 * y == 0)", 1);
+  EXPECT_TRUE(rel.isEqual(expected));
+}