diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -467,15 +467,27 @@
 
   /// Adds additional local ids to the sets such that they both have the union
   /// of the local ids in each set, without changing the set of points that
-  /// lie in `this` and `other`. The ordering of the local ids in the
-  /// sets may also be changed. After merging, if the `i^th` local variable in
-  /// one set has a known division representation, then the `i^th` local
-  /// variable in the other set either has the same division representation or
-  /// no known division representation.
+  /// lie in `this` and `other`.
   ///
-  /// The number of dimensions and symbol ids in `this` and `other` should
-  /// match.
-  void mergeLocalIds(IntegerRelation &other);
+  /// While taking the union, if a local id in `other` has a division
+  /// representation which duplicates the division representation of another
+  /// local id, it is not added to the final union of local ids and is instead
+  /// merged. The new ordering of local ids is:
+  ///
+  /// [Local ids of `this`] [Non-merged local ids of `other`]
+  ///
+  /// The relative ordering of local ids is the same as before.
+  ///
+  /// After merging, if the `i^th` local variable in one set has a known
+  /// division representation, then the `i^th` local variable in the other set
+  /// either has the same division representation or no known division
+  /// representation.
+  ///
+  /// The spaces of both relations should be compatible.
+  ///
+  /// Returns the number of non-merged local ids of `other`, i.e. the number of
+  /// locals that have been added to `this`.
+  unsigned mergeLocalIds(IntegerRelation &other);
 
   /// Changes the partition between dimensions and symbols. Depending on the new
   /// symbol count, either a chunk of dimensional identifiers immediately before
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1082,19 +1082,24 @@
 /// of the local ids in each set, without changing the set of points that
 /// lie in `this` and `other`.
 ///
-/// To detect local ids that always take the same in both sets, each local id is
+/// To detect local ids that always take the same value, each local id is
 /// represented as a floordiv with constant denominator in terms of other ids.
-/// After extracting these divisions, local ids with the same division
-/// representation are considered duplicate and are merged. It is possible that
-/// division representation for some local id cannot be obtained, and thus these
-/// local ids are not considered for detecting duplicates.
-void IntegerRelation::mergeLocalIds(IntegerRelation &other) {
+/// After extracting these divisions, local ids in `other` with the same
+/// division representation as some other local id in any set are considered
+/// duplicate and are merged.
+///
+/// It is possible that division representation for some local id cannot be
+/// obtained, and thus these local ids are not considered for detecting
+/// duplicates.
+unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
   assert(space.isCompatible(other.getSpace()) &&
          "Spaces should be compatible.");
 
   IntegerRelation &relA = *this;
   IntegerRelation &relB = other;
 
+  unsigned oldALocals = relA.getNumLocalIds();
+
   // Merge local ids of relA and relB without using division information,
   // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
   // to `relB` at start of its local ids.
@@ -1119,7 +1124,17 @@
 
   // Merge function that merges the local variables in both sets by treating
   // them as the same identifier.
-  auto merge = [&relA, &relB](unsigned i, unsigned j) -> bool {
+  auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool {
+    // We only merge from local at pos j to local at pos i, where j > i.
+    if (i >= j)
+      return false;
+
+    // If j < oldALocals, then i < j < oldALocals, so both locals come from A.
+    // Since we do not want to merge duplicates in A, we ignore this call.
+    if (j < oldALocals)
+      return false;
+
+    // Merge local at pos j into local at position i.
     relA.eliminateRedundantLocalId(i, j);
     relB.eliminateRedundantLocalId(i, j);
     return true;
@@ -1128,6 +1143,10 @@
   // Merge all divisions by removing duplicate divisions.
   unsigned localOffset = getIdKindOffset(IdKind::Local);
   presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
+
+  // Merging only ever removes locals at positions >= oldALocals, so this is
+  // guaranteed to be non-negative.
+  return relA.getNumLocalIds() - oldALocals;
 }
 
 void IntegerRelation::removeDuplicateDivs() {
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1016,6 +1016,30 @@
   }
 }
 
+TEST(IntegerPolyhedronTest, mergeDivisionsDuplicateInSameSet) {
+  // (x) : (exists y = [x + 1 / 3], z = [x + 1 / 3]: y + z >= x).
+  IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
+  poly1.addLocalFloorDiv({1, 1}, 3);    // y = [x + 1 / 3].
+  poly1.addLocalFloorDiv({1, 0, 1}, 3); // z = [x + 1 / 3].
+  poly1.addInequality({-1, 1, 1, 0});   // y + z >= x.
+
+  // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x).
+  IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
+  poly2.addLocalFloorDiv({1, 1}, 3);    // y = [x + 1 / 3].
+  poly2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
+  poly2.addInequality({1, -1, -1, 0});  // y + z <= x.
+
+  // Only y of poly2 duplicates a division, so exactly one local is added.
+  EXPECT_EQ(poly1.mergeLocalIds(poly2), 1u);
+
+  // Local spaces should be the same.
+  EXPECT_EQ(poly1.getNumLocalIds(), poly2.getNumLocalIds());
+
+  // The duplicate divisions already present in poly1 are not merged.
+  EXPECT_EQ(poly1.getNumLocalIds(), 3u);
+  EXPECT_EQ(poly2.getNumLocalIds(), 3u);
+}
+
 TEST(IntegerPolyhedronTest, negativeDividends) {
   // (x) : (exists y = [-x + 1 / 2], z = [-x - 2 / 3]: y + z >= x).
   IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));