diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -459,10 +459,16 @@ void removeDuplicateDivs(); /// Converts identifiers of kind srcKind in the range [idStart, idLimit) to - /// variables of kind dstKind and placed after all the other variables of kind - /// dstKind. The internal ordering among the moved variables is preserved. + /// variables of kind dstKind. If `pos` is given, the variables are placed at + /// position `pos` of dstKind, otherwise they are placed after all the other + /// variables of kind dstKind. The internal ordering among the moved variables + /// is preserved. void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit, - IdKind dstKind); + IdKind dstKind, unsigned pos); + void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit, + IdKind dstKind) { + convertIdKind(srcKind, idStart, idLimit, dstKind, getNumIdKind(dstKind)); + } void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) { convertIdKind(kind, idStart, idLimit, IdKind::Local); } @@ -523,6 +529,34 @@ /// modifies R to be B -> A. void inverse(); + /// Let the relation `this` be R1, and the relation `rel` be R2. Modifies R1 + /// to be the composition of R1 followed by R2 (written R1;R2). + /// + /// Formally, if R1: A -> B and R2: B -> C, then R1 is replaced by R3: A -> C + /// such that a point (a, c) belongs to R3 iff there exists b such that + /// (a, b) belongs to R1 and (b, c) belongs to R2. + void compose(const IntegerRelation &rel); + + /// Given a relation `rel`, apply the relation to the domain of this relation. 
+ /// + /// For example, given the relations: + /// R1: (i -> j, k) : (0 <= i < 100, 0 <= j < 10, 0 <= k < 10) + /// R2: (i -> i') : (i' = i floordiv 4) + /// R1.applyDomain(R2): (i -> j, k) : (0 <= i < 25, 0 <= j < 10, 0 <= k < 10) + /// + /// Formally, R1.applyDomain(R2) = R2.inverse().compose(R1). + void applyDomain(const IntegerRelation &rel); + + /// Given a relation `rel`, apply the relation to the range of this relation. + /// + /// For example, given the relations: + /// R1: (i -> i') : (i = i' floordiv 4) + /// R2: (i -> j, k) : (0 <= i < 100, 0 <= j < 10, 0 <= k < 10) + /// R1.applyRange(R2): (i -> j, k) : (0 <= i < 25, 0 <= j < 10, 0 <= k < 10) + /// + /// Formally, R1.applyRange(R2) = R1.compose(R2). + void applyRange(const IntegerRelation &rel); + void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1184,16 +1184,16 @@ } void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart, - unsigned idLimit, IdKind dstKind) { + unsigned idLimit, IdKind dstKind, + unsigned pos) { assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range"); if (idStart >= idLimit) return; // Append new local variables corresponding to the dimensions to be converted. - unsigned newIdsBegin = getIdKindEnd(dstKind); unsigned convertCount = idLimit - idStart; - appendId(dstKind, convertCount); + unsigned newIdsBegin = insertId(dstKind, pos, convertCount); // Swap the new local variables with dimensions. 
// @@ -2137,6 +2137,40 @@ convertIdKind(IdKind::Range, 0, numRangeIds, IdKind::Domain); } +void IntegerRelation::compose(const IntegerRelation &rel) { + assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) && + "Range of `this` should be compatible with Domain of `rel`"); + + IntegerRelation copyRel = rel; + + // Let relation `this` be R1: A -> B, and `rel` be R2: B -> C. + // We convert R1 to A -> (B X C) and R2 to the set (B X C), then intersect + // the range of R1 with R2. Projecting out B then gives R1: A -> C. + unsigned numBIds = getNumRangeIds(); + + // Convert R1 from A -> B to A -> (B X C) by appending C's ids to the range. + appendId(IdKind::Range, copyRel.getNumRangeIds()); + + // Convert R2 from B -> C to the set (B X C) by moving B into the range. + // TODO: Using nested spaces here would help, since we could directly + // intersect the range with another relation. + copyRel.convertIdKind(IdKind::Domain, 0, numBIds, IdKind::Range, 0); + + // Intersect the range of R1 with R2; both list B's ids before C's. + intersectRange(IntegerPolyhedron(copyRel)); + + // Project out B in R1 by converting its ids to local ids. 
+ convertIdKind(IdKind::Range, 0, numBIds, IdKind::Local); +} + +void IntegerRelation::applyDomain(const IntegerRelation &rel) { + inverse(); + compose(rel); + inverse(); +} + +void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); } + void IntegerRelation::printSpace(raw_ostream &os) const { space.print(os); os << getNumConstraints() << " constraints\n"; diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -91,3 +91,34 @@ EXPECT_TRUE(copyRel.isEqual(expectedRel)); } } + +TEST(IntegerRelationTest, applyDomainAndRange) { + + { + IntegerRelation map1 = parseRelationFromSet( + "(x, y, a, b)[N] : (a - x - N == 0, b - y + N == 0)", 2); + IntegerRelation map2 = + parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2); + + map1.applyRange(map2); + + IntegerRelation map3 = + parseRelationFromSet("(x, y, a)[N] : (a - x - y == 0)", 2); + + EXPECT_TRUE(map1.isEqual(map3)); + } + + { + IntegerRelation map1 = parseRelationFromSet( + "(x, y, a, b)[N] : (a - x + N == 0, b - y - N == 0)", 2); + IntegerRelation map2 = + parseRelationFromSet("(x, y, a, b)[N] : (a - N == 0, b - N == 0)", 2); + + IntegerRelation map3 = + parseRelationFromSet("(x, y, a, b)[N] : (x - N == 0, y - N == 0)", 2); + + map1.applyDomain(map2); + + EXPECT_TRUE(map1.isEqual(map3)); + } +}