diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -441,10 +441,16 @@ /// variables. void convertDimToLocal(unsigned dimStart, unsigned dimLimit); - /// Merge local ids of `this` and `other`. This is done by appending local ids - /// of `other` to `this` and inserting local ids of `this` to `other` at start - /// of its local ids. Number of dimension and symbol ids should match in - /// `this` and `other`. + /// Adds additional local ids to the sets such that they both have the union + /// of the local ids in each set, without changing the set of points that + /// lie in `this` and `other`. The ordering of the local ids in the + /// sets may also be changed. After merging, if the `i^th` local variable in + /// one set has a known division representation, then the `i^th` local + /// variable in the other set either has the same division representation or + /// no known division representation. + /// + /// The number of dimensions and symbol ids in `this` and `other` should + /// match. void mergeLocalIds(FlatAffineConstraints &other); /// Removes all equalities and inequalities. @@ -819,8 +825,8 @@ /// constraint systems are updated so that they have the union of all /// identifiers, with `this`'s original identifiers appearing first followed /// by any of `other`'s identifiers that didn't appear in `this`. Local - /// identifiers of each system are by design separate/local and are placed - /// one after other (`this`'s followed by `other`'s). + /// identifiers in `other` that have the same division representation as local + /// identifiers in `this` are merged into one. 
// E.g.: Input: `this` has (%i, %j) [%M, %N] // `other` has (%k, %j) [%P, %N, %M] // Output: both `this`, `other` have (%i, %j, %k) [%M, %N, %P] diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -493,8 +493,8 @@ /// dimension-wise and symbol-wise unique; both constraint systems are updated /// so that they have the union of all identifiers, with A's original /// identifiers appearing first followed by any of B's identifiers that didn't -/// appear in A. Local identifiers of each system are by design separate/local -/// and are placed one after other (A's followed by B's). +/// appear in A. Local identifiers in B that have the same division +/// representation as local identifiers in A are merged into one. // E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M]) // Output: both A, B have (%i, %j, %k) [%M, %N, %P] static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a, @@ -1918,18 +1918,108 @@ equalities.resizeVertically(pos); } -/// Merge local ids of `this` and `other`. This is done by appending local ids -/// of `other` to `this` and inserting local ids of `this` to `other` at start -/// of its local ids. Number of dimension and symbol ids should match in -/// `this` and `other`. +/// Eliminate the `pos2^th` local identifier, replacing every instance of it +/// with the `pos1^th` local identifier. This function is intended to be used to +/// remove redundancy when local variables at position `pos1` and `pos2` are +/// restricted to have the same value. 
+static void eliminateRedundantLocalId(FlatAffineConstraints &fac, unsigned pos1, + unsigned pos2) { + + assert(pos1 < fac.getNumLocalIds() && "Invalid local id position"); + assert(pos2 < fac.getNumLocalIds() && "Invalid local id position"); + + unsigned localOffset = fac.getNumDimAndSymbolIds(); + pos1 += localOffset; + pos2 += localOffset; + for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i) + fac.atIneq(i, pos1) += fac.atIneq(i, pos2); + for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i) + fac.atEq(i, pos1) += fac.atEq(i, pos2); + fac.removeId(pos2); +} + +/// Adds additional local ids to the sets such that they both have the union +/// of the local ids in each set, without changing the set of points that +/// lie in `this` and `other`. +/// +/// To detect local ids that always take the same value in both sets, each local +/// id is represented as a floordiv with constant denominator in terms of other +/// ids. After extracting these divisions, local ids with the same division +/// representation are considered duplicate and are merged. It is possible that +/// division representation for some local id cannot be obtained, and thus these +/// local ids are not considered for detecting duplicates. void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) { assert(getNumDimIds() == other.getNumDimIds() && "Number of dimension ids should match"); assert(getNumSymbolIds() == other.getNumSymbolIds() && "Number of symbol ids should match"); - unsigned initLocals = getNumLocalIds(); - insertLocalId(getNumLocalIds(), other.getNumLocalIds()); - other.insertLocalId(0, initLocals); + + FlatAffineConstraints &facA = *this; + FlatAffineConstraints &facB = other; + + // Merge local ids of facA and facB without using division information, + // i.e. append local ids of `facB` to `facA` and insert local ids of `facA` + // to `facB` at start of its local ids. 
+ unsigned initLocals = facA.getNumLocalIds(); + insertLocalId(facA.getNumLocalIds(), facB.getNumLocalIds()); + facB.insertLocalId(0, initLocals); + + // Get division representations from each FAC. + std::vector<SmallVector<int64_t, 8>> divsA, divsB; + SmallVector<unsigned, 4> denomsA, denomsB; + facA.getLocalReprs(divsA, denomsA); + facB.getLocalReprs(divsB, denomsB); + + // Copy division information for facB into `divsA` and `denomsA`, so that + // these have the combined division information of both FACs. Since newly + // added local variables in facA and facB have no constraints, they will not + // have any division representation. + std::copy(divsB.begin() + initLocals, divsB.end(), + divsA.begin() + initLocals); + std::copy(denomsB.begin() + initLocals, denomsB.end(), + denomsA.begin() + initLocals); + + // Find and merge duplicate divisions. + // TODO: Add division normalization to support divisions that differ by + // a constant. + // TODO: Add division ordering such that a division representation for local + // identifier at position `i` only depends on local identifiers at position < + // `i`. This would make sure that all divisions depending on other local + // variables that can be merged, are merged. + unsigned localOffset = getIdKindOffset(IdKind::Local); + for (unsigned i = 0; i < divsA.size(); ++i) { + // Check if a division representation exists for the `i^th` local id. + if (denomsA[i] == 0) + continue; + // Check if a division exists which is a duplicate of the division at `i`. + for (unsigned j = i + 1; j < divsA.size(); ++j) { + // Check if a division representation exists for the `j^th` local id. + if (denomsA[j] == 0) + continue; + // Check if the denominators match. + if (denomsA[i] != denomsA[j]) + continue; + // Check if the representations are equal. + if (divsA[i] != divsA[j]) + continue; + + // Merge divisions at position `j` into division at position `i`. 
+ eliminateRedundantLocalId(facA, i, j); + eliminateRedundantLocalId(facB, i, j); + for (unsigned k = 0, g = divsA.size(); k < g; ++k) { + SmallVector<int64_t, 8> &div = divsA[k]; + if (denomsA[k] != 0) { + div[localOffset + i] += div[localOffset + j]; + div.erase(div.begin() + localOffset + j); + } + } + + divsA.erase(divsA.begin() + j); + denomsA.erase(denomsA.begin() + j); + // Since `j` starts at `i + 1`, it is never zero here and `--j` cannot wrap. + --j; + } + } } /// Removes local variables using equalities. Each equality is checked if it diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -809,4 +809,127 @@ EXPECT_TRUE(fac3.isEmpty()); } +TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) { + { + // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). + FlatAffineConstraints fac1(1, 0, 1); + fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2]. + fac1.addEquality({1, 0, -3, 0}); // x = 3y. + fac1.addInequality({1, 1, 0, 1}); // x + z + 1 >= 0. + + // (x) : (exists y = [x / 2], z : x = 5y). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addEquality({1, -5, 0}); // x = 5y. + fac2.appendLocalId(); // Add local id z. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 1 division should be matched + 2 unmatched local ids. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } + + { + // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 5); // z = [x / 5]. + fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2]. + fac1.addEquality({1, 0, -3, 0}); // x = 3y. + + // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z). 
+ FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({1, 0, 0}, 5); // z = [x / 5]. + fac2.addEquality({1, 0, -5, 0}); // x = 5z. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } +} + +TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) { + { + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } + + { + // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac1.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5]. + fac1.addInequality({-1, 1, 1, 0, 0}); // y + z >= x. + + // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac2.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5]. + fac2.addInequality({1, -1, -1, 0, 0}); // y + z <= x. 
+ + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 3 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } +} + +TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) { + { + // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. + fac1.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. + fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } +} + } // namespace mlir