diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -459,10 +459,16 @@ void removeDuplicateDivs(); /// Converts identifiers of kind srcKind in the range [idStart, idLimit) to - /// variables of kind dstKind and placed after all the other variables of kind - /// dstKind. The internal ordering among the moved variables is preserved. + /// variables of kind dstKind. If `pos` is given, the variables are placed at + /// position `pos` of dstKind, otherwise they are placed after all the other + /// variables of kind dstKind. The internal ordering among the moved variables + /// is preserved. void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit, - IdKind dstKind); + IdKind dstKind, unsigned pos); + void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit, + IdKind dstKind) { + convertIdKind(srcKind, idStart, idLimit, dstKind, getNumIdKind(dstKind)); + } void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) { convertIdKind(kind, idStart, idLimit, IdKind::Local); } @@ -523,6 +529,32 @@ /// modifies R to be B -> A. void inverse(); + /// Let the relation `this` be R1, and the relation `rel` be R2. Modifies R1 + /// to be the composition of R1 and R2: R1;R2. + /// + /// Formally, if R1: A -> B, and R2: B -> C, then this function replaces R1 + /// with the relation R3: A -> C such that a point (a, c) belongs to R3 iff + /// there exists b such that (a, b) is in R1 and (b, c) is in R2. + void compose(const IntegerRelation &rel); + + /// Given a relation `rel`, apply the relation to the domain of this relation. + /// + /// R1: i -> j : (0 <= i < 2, j = i) + /// R2: i -> k : (k = i floordiv 2) + /// R3: k -> j : (0 <= k < 1, 2k <= j <= 2k + 1) + /// + /// R1 = {(0, 0), (1, 1)}. R2 maps both 0 and 1 to 0. + /// So R1.applyDomain(R2) = R3 = {(0, 0), (0, 1)}. 
+ /// + /// Formally, R1.applyDomain(R2) = R2.inverse().compose(R1). + void applyDomain(const IntegerRelation &rel); + + /// Given a relation `rel`, apply the relation to the range of this relation. + /// + /// Formally, R1.applyRange(R2) is the same as R1.compose(R2) but we provide + /// this for uniformity with `applyDomain`. + void applyRange(const IntegerRelation &rel); + void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1184,16 +1184,16 @@ } void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart, - unsigned idLimit, IdKind dstKind) { + unsigned idLimit, IdKind dstKind, + unsigned pos) { assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range"); if (idStart >= idLimit) return; // Append new local variables corresponding to the dimensions to be converted. - unsigned newIdsBegin = getIdKindEnd(dstKind); unsigned convertCount = idLimit - idStart; - appendId(dstKind, convertCount); + unsigned newIdsBegin = insertId(dstKind, pos, convertCount); // Swap the new local variables with dimensions. // @@ -2137,6 +2137,40 @@ convertIdKind(IdKind::Range, 0, numRangeIds, IdKind::Domain); } +void IntegerRelation::compose(const IntegerRelation &rel) { + assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) && + "Range of `this` should be compatible with Domain of `rel`"); + + IntegerRelation copyRel = rel; + + // Let relation `this` be R1: A -> B, and `rel` be R2: B -> C. + // We convert R1 to A -> (B X C), and R2 to B X C, then intersect the range + // of R1 with R2. After this, we get R1: A -> C, by projecting out B. + // TODO: Using nested spaces here would help, since we could directly + // intersect the range with another relation. 
+ unsigned numBIds = getNumRangeIds(); + + // Convert R1 from A -> B to A -> (B X C). + appendId(IdKind::Range, copyRel.getNumRangeIds()); + + // Convert R2 to B X C. + copyRel.convertIdKind(IdKind::Domain, 0, numBIds, IdKind::Range, 0); + + // Intersect the range of R1 with R2. + intersectRange(IntegerPolyhedron(copyRel)); + + // Project out B in R1 by turning its variables into locals. + convertIdKind(IdKind::Range, 0, numBIds, IdKind::Local); +} + +void IntegerRelation::applyDomain(const IntegerRelation &rel) { + inverse(); + compose(rel); + inverse(); +} + +void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); } + void IntegerRelation::printSpace(raw_ostream &os) const { space.print(os); os << getNumConstraints() << " constraints\n"; diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -91,3 +91,34 @@ EXPECT_TRUE(copyRel.isEqual(expectedRel)); } } + +TEST(IntegerRelationTest, applyDomainAndRange) { + + { + IntegerRelation map1 = parseRelationFromSet( + "(x, y, a, b)[N] : (a - x - N == 0, b - y + N == 0)", 2); + IntegerRelation map2 = + parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2); + + map1.applyRange(map2); + + IntegerRelation map3 = + parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2); + + EXPECT_TRUE(map1.isEqual(map3)); + } + + { + IntegerRelation map1 = parseRelationFromSet( + "(x, y, a, b)[N] : (a - x + N == 0, b - y - N == 0)", 2); + IntegerRelation map2 = + parseRelationFromSet("(x, y, a, b)[N] : (a - N == 0, b - N == 0)", 2); + + IntegerRelation map3 = + parseRelationFromSet("(x, y, a, b)[N] : (x - N == 0, y - N == 0)", 2); + + map1.applyDomain(map2); + + EXPECT_TRUE(map1.isEqual(map3)); + } +}