diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1938,16 +1938,51 @@
   fac.removeId(pos2);
 }
 
+/// Normalize the `dividend` and the `denominator` of each division by their
+/// gcd. For ex: if the dividend and denominator are [2,0,4] and 4
+/// respectively, they get normalized to [1,0,2] and 2. This is sound since
+/// floor((g * e) / (g * d)) == floor(e / d) for any g > 0. Divisions without
+/// a representation (denominator 0) are left unchanged.
+static void normalizeDivisions(std::vector<SmallVector<int64_t, 8>> &divs,
+                               SmallVectorImpl<unsigned> &denoms) {
+  assert(divs.size() == denoms.size() &&
+         "Division and denominator counts should match");
+  for (size_t i = 0, m = divs.size(); i < m; ++i) {
+    // A zero denominator means no division representation was found.
+    if (denoms[i] == 0)
+      continue;
+
+    // Find the gcd of the denominator and all coefficients of the dividend.
+    // Work on absolute values: `llvm::greatestCommonDivisor` can return a
+    // negative result for negative inputs, and dividing by it would flip the
+    // sign of the dividend and wrap the unsigned denominator.
+    uint64_t gcd = denoms[i];
+    for (int64_t coeff : divs[i]) {
+      gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(coeff));
+      if (gcd == 1)
+        break;
+    }
+    if (gcd == 1)
+      continue;
+
+    // Normalize the denominator and the dividend.
+    denoms[i] /= static_cast<unsigned>(gcd);
+    for (int64_t &coeff : divs[i])
+      coeff /= static_cast<int64_t>(gcd);
+  }
+}
+
 /// Adds additional local ids to the sets such that they both have the union
 /// of the local ids in each set, without changing the set of points that
 /// lie in `this` and `other`.
 ///
-/// To detect local ids that always take the same in both sets, each local id is
-/// represented as a floordiv with constant denominator in terms of other ids.
-/// After extracting these divisions, local ids with the same division
-/// representation are considered duplicate and are merged. It is possible that
-/// division representation for some local id cannot be obtained, and thus these
-/// local ids are not considered for detecting duplicates.
+/// To detect local ids that always take the same value in both sets, each
+/// local id is represented as a floordiv with constant denominator in terms
+/// of other ids. After extracting these divisions, local ids with the same
+/// division representation are considered duplicate and are merged. It is
+/// possible that division representation for some local id cannot be
+/// obtained, and thus these local ids are not considered for detecting
+/// duplicates.
 void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
   assert(getNumDimIds() == other.getNumDimIds() &&
          "Number of dimension ids should match");
@@ -1970,6 +2005,9 @@
   facA.getLocalReprs(divsA, denomsA);
   facB.getLocalReprs(divsB, denomsB);
 
+  normalizeDivisions(divsA, denomsA);
+  normalizeDivisions(divsB, denomsB);
+
   // Copy division information for facB into `divsA` and `denomsA`, so that
   // these have the combined division information of both FACs. Since newly
   // added local variables in facA and facB have no constraints, they will not
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -813,6 +813,29 @@
     EXPECT_EQ(fac1.getNumLocalIds(), 2u);
     EXPECT_EQ(fac2.getNumLocalIds(), 2u);
   }
+
+  {
+    // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0).
+    FlatAffineConstraints fac1(1, 0, 1);
+    fac1.addLocalFloorDiv({3, 0, 0}, 6); // y = [3x / 6] -> [x/2].
+    fac1.addEquality({1, 0, -3, 0});     // x = 3y.
+    fac1.addInequality({1, 1, 0, 1});    // x + z + 1 >= 0.
+
+    // (x) : (exists y = [x / 2], z : x = 5y).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
+    fac2.addEquality({1, -5, 0});     // x = 5y.
+    fac2.appendLocalId();             // Add local id z.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 1 division should be matched + 2 unmatched local ids.
+    EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+  }
 }
 
 TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) {
@@ -863,6 +886,30 @@
     EXPECT_EQ(fac1.getNumLocalIds(), 3u);
     EXPECT_EQ(fac2.getNumLocalIds(), 3u);
   }
+  {
+    // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
+    FlatAffineConstraints fac1(1);
+    // Normalization test.
+    fac1.addLocalFloorDiv({2, 0}, 4);    // y = [x / 2].
+    fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
+    fac1.addInequality({-1, 1, 1, 0});   // y + z >= x.
+
+    // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
+    // Normalization test.
+    fac2.addLocalFloorDiv({3, 3, 0}, 9); // z = [x + y / 3].
+    fac2.addInequality({1, -1, -1, 0});  // y + z <= x.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 2 divisions should be matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
 }
 
 TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) {
@@ -884,6 +931,30 @@
     // Local space should be same.
     EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
 
+    // 2 divisions should be matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
+  {
+    // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z >= x).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2].
+    // Normalization test.
+    fac1.addLocalFloorDiv({3, 0, 6}, 9); // z = [x + 2 / 3].
+    fac1.addInequality({-1, 1, 1, 0});   // y + z >= x.
+
+    // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z <= x).
+    FlatAffineConstraints fac2(1);
+    // Normalization test.
+    fac2.addLocalFloorDiv({2, 2}, 4);    // y = [x + 1 / 2].
+    fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
+    fac2.addInequality({1, -1, -1, 0});  // y + z <= x.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
     // 2 divisions should be matched.
     EXPECT_EQ(fac1.getNumLocalIds(), 2u);
     EXPECT_EQ(fac2.getNumLocalIds(), 2u);