diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1352,6 +1352,48 @@
   return true;
 }
 
+/// Normalize a division's `dividend` and the `divisor` by their GCD. For
+/// example: if the dividend and divisor are [2,0,4] and 4 respectively,
+/// they get normalized to [1,0,2] and 2.
+///
+/// The GCD is computed from absolute values, so it is always positive. The
+/// constant term (the last element) need not be a multiple of the GCD and is
+/// divided with floor semantics; e.g. [2,-1] and 4 become [1,-1] and 2, since
+/// floor((2x - 1) / 4) == floor((x - 1) / 2) for every integer x.
+static void normalizeDivisionsByGCD(SmallVectorImpl<int64_t> &dividend,
+                                    unsigned &divisor) {
+  if (divisor == 0 || dividend.empty())
+    return;
+  // Take absolute values: with negative operands `greatestCommonDivisor` can
+  // return a negative result, which would flip the sign of the dividend and
+  // wrap around the unsigned `divisor`.
+  int64_t gcd = llvm::greatestCommonDivisor(std::abs(dividend.front()),
+                                            int64_t(divisor));
+
+  // The reason for ignoring the constant term is as follows.
+  // For a division:
+  //      floor((a + m.f(x))/(m.d))
+  // It can be replaced by:
+  //      floor((floor(a/m) + f(x))/d)
+  // Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not
+  // influence the result of the floor division and thus, can be ignored.
+  for (size_t i = 1, m = dividend.size() - 1; i < m; ++i) {
+    gcd = llvm::greatestCommonDivisor(std::abs(dividend[i]), gcd);
+    if (gcd == 1)
+      return;
+  }
+
+  // Normalize the dividend and the denominator. C++ `/` truncates towards
+  // zero, but the constant term needs floor(a/m) (see above), so step the
+  // quotient down by one when the remainder is negative (`gcd` is positive).
+  std::transform(dividend.begin(), dividend.end(), dividend.begin(),
+                 [gcd](int64_t n) {
+                   int64_t quotient = n / gcd;
+                   return n % gcd < 0 ? quotient - 1 : quotient;
+                 });
+  divisor /= static_cast<unsigned>(gcd);
+}
+
 /// Check if the pos^th identifier can be represented as a division using upper
 /// bound inequality at position `ubIneq` and lower bound inequality at position
 /// `lbIneq`.
@@ -1375,7 +1417,8 @@
 /// q = (i + j + 1) floordiv 4
 ///
 /// If successful, `expr` is set to dividend of the division and `divisor` is
-/// set to the denominator of the division.
+/// set to the denominator of the division. The final division expression is
+/// normalized by GCD.
 static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos,
                                 unsigned ubIneq, unsigned lbIneq,
                                 SmallVector<int64_t, 8> &expr,
@@ -1419,6 +1462,7 @@
 
   // Set divisor.
divisor = denominator; + normalizeDivisionsByGCD(expr, divisor); return success(); } diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -691,9 +691,9 @@ fac.addLocalFloorDiv({0, 0, 0, 0, 10}, 30); fac.addLocalFloorDiv({0, 0, 0, 0, 0, 99}, 101); - std::vector> divisions = {{0, 0, 0, 0, 0, 0, 10}, - {0, 0, 0, 0, 0, 0, 99}}; - SmallVector denoms = {30, 101}; + std::vector> divisions = {{0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0}}; + SmallVector denoms = {1, 1}; // Check if floordivs with constant numerator can be computed. checkDivisionRepresentation(fac, divisions, denoms); @@ -813,6 +813,30 @@ EXPECT_EQ(fac1.getNumLocalIds(), 2u); EXPECT_EQ(fac2.getNumLocalIds(), 2u); } + + { + // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). + FlatAffineConstraints fac1(1, 0, 1); + // This division would be normalized. + fac1.addLocalFloorDiv({3, 0, 0}, 6); // y = [3x / 6] -> [x/2]. + fac1.addEquality({1, 0, -3, 0}); // x = 3y. + fac1.addInequality({1, 1, 0, 1}); // x + z + 1 >= 0. + + // (x) : (exists y = [x / 2], z : x = 5y). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addEquality({1, -5, 0}); // x = 5y. + fac2.appendLocalId(); // Add local id z. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 1 division should be matched + 2 unmatched local ids. + EXPECT_EQ(fac1.getNumLocalIds(), 3u); + EXPECT_EQ(fac2.getNumLocalIds(), 3u); + } } TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) { @@ -863,6 +887,29 @@ EXPECT_EQ(fac1.getNumLocalIds(), 3u); EXPECT_EQ(fac2.getNumLocalIds(), 3u); } + { + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + // This division would be normalized.
+ fac1.addLocalFloorDiv({2, 0}, 4); // y = [2x / 4] -> [x / 2]. + fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. + fac2.addLocalFloorDiv({3, 3, 0}, 9); // z = [3x + 3y / 9] -> [x + y / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } } TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) { @@ -884,6 +931,30 @@ // Local space should be same. EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + // 2 divisions should be matched. + EXPECT_EQ(fac1.getNumLocalIds(), 2u); + EXPECT_EQ(fac2.getNumLocalIds(), 2u); + } + { + // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z >= x). + FlatAffineConstraints fac1(1); + fac1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. + // Normalization test. + fac1.addLocalFloorDiv({3, 0, 6}, 9); // z = [3x + 6 / 9] -> [x + 2 / 3]. + fac1.addInequality({-1, 1, 1, 0}); // y + z >= x. + + // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z <= x). + FlatAffineConstraints fac2(1); + // Normalization test. + fac2.addLocalFloorDiv({2, 2}, 4); // y = [2x + 2 / 4] -> [x + 1 / 2]. + fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. + fac2.addInequality({1, -1, -1, 0}); // y + z <= x. + + fac1.mergeLocalIds(fac2); + + // Local space should be same. + EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds()); + // 2 divisions should be matched. EXPECT_EQ(fac1.getNumLocalIds(), 2u); EXPECT_EQ(fac2.getNumLocalIds(), 2u);