diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1352,6 +1352,31 @@ return true; } +/// Normalize the `expr` and the `denominator` by removing factors that are +/// common between them. For ex: if the dividend and denominator are [2,0,4] and +/// 4 respectively, they get normalized to [1,0,2] and 2. +static void normalizeDivisions(SmallVectorImpl &expr, + unsigned &denominator) { + if (denominator == 0) + return; + int64_t gcd = expr.size() > 0 + ? llvm::greatestCommonDivisor(expr[0], int64_t(denominator)) + : 1; + + // The constant term is ignored while taking gcd since + // floor((a + m.f(x))/(m.d)) can be replaced by floor((floor(a/m) + f(x))/d). + for (size_t i = 1, m = expr.size() - 1; i < m; i++) { + gcd = llvm::greatestCommonDivisor(expr[i], gcd); + if (gcd == 1) + return; + } + + // Normalize the expr and denominator respectively. + std::transform(expr.begin(), expr.end(), expr.begin(), + [&gcd](int64_t &n) { return floor(n / gcd); }); + denominator /= gcd; +} + /// Check if the pos^th identifier can be represented as a division using upper /// bound inequality at position `ubIneq` and lower bound inequality at position /// `lbIneq`. @@ -1375,7 +1400,8 @@ /// q = (i + j + 1) floordiv 4 /// /// If successful, `expr` is set to dividend of the division and `divisor` is -/// set to the denominator of the division. +/// set to the denominator of the division. Finally, the `expr` and `divisor` +/// are normalized by taking the common factor between them. static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos, unsigned ubIneq, unsigned lbIneq, SmallVector &expr, @@ -1419,6 +1445,7 @@ // Set divisor. divisor = denominator; + normalizeDivisions(expr, divisor); return success(); } @@ -1938,6 +1965,7 @@ fac.removeId(pos2); } + /// Adds additional local ids to the sets such that they both have the union /// of the local ids in each set, without changing the set of points that /// lie in `this` and `other`. diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -691,9 +691,9 @@ fac.addLocalFloorDiv({0, 0, 0, 0, 10}, 30); fac.addLocalFloorDiv({0, 0, 0, 0, 0, 99}, 101); - std::vector> divisions = {{0, 0, 0, 0, 0, 0, 10}, - {0, 0, 0, 0, 0, 0, 99}}; - SmallVector denoms = {30, 101}; + std::vector> divisions = {{0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0}}; + SmallVector denoms = {1, 1}; // Check if floordivs with constant numerator can be computed. checkDivisionRepresentation(fac, divisions, denoms); @@ -813,6 +813,29 @@ EXPECT_EQ(fac1.getNumLocalIds(), 2u); EXPECT_EQ(fac2.getNumLocalIds(), 2u); } + + { + // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). + FlatAffineConstraints fac1(1, 0, 1); + fac1.addLocalFloorDiv({3, 0, 0}, 6); // y = [3x / 6] -> [x/2]. + fac1.addEquality({1, 0, -3, 0}); // x = 3z. + fac1.addInequality({1, 1, 0, 1}); // x + y + 1 >= 0. + + // (x) : (exists y = [x / 2], z : x = 5y). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addEquality({1, -5, 0}); // x = 5y. + fac2.appendLocalId(); // Add local id z. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 1 division should be matched + 2 unmatched local ids. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } } TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) { @@ -863,6 +886,30 @@ EXPECT_EQ(fac1.getNumLocalIds(), 3u); EXPECT_EQ(fac2.getNumLocalIds(), 3u); } + { + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + // Normalization test. + fac1.addLocalFloorDiv({2, 0}, 4); // y = [x / 2]. + fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + // Normalization test. + fac2.addLocalFloorDiv({3, 3, 0}, 9); // z = [x + y / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } } TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) { @@ -884,6 +931,30 @@ // Local space should be same. EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } + { + // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. + // Normalization test. + fac1.addLocalFloorDiv({3, 0, 6}, 9); // z = [x + 2 / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + // Normalization test. + fac2.addLocalFloorDiv({2, 2}, 4); // y = [x + 1 / 2]. + fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + // 2 divisions should be matched. EXPECT_EQ(fac1.getNumLocalIds(), 2u); EXPECT_EQ(fac2.getNumLocalIds(), 2u);