diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -156,6 +156,11 @@
     denoms[i] = divisor;
   }
 
+  // Normalize every division that has a representation (non-zero divisor):
+  // find the greatest common divisor (GCD) of its dividend coefficients and
+  // divisor, and divide them all by it to simplify the division.
+  void normalizeDivs();
+
   void insertDiv(unsigned pos, ArrayRef<MPInt> dividend, const MPInt &divisor);
   void insertDiv(unsigned pos, unsigned num = 1);
 
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -437,6 +437,9 @@
   // variable at position `i` only depends on local variables at position <
   // `i`. This would make sure that all divisions depending on other local
   // variables that can be merged, are merged.
+  // Normalize first, so that divisions differing only by a common factor,
+  // e.g. (2x + 2) / 4 and (x + 1) / 2, are detected as duplicates.
+  normalizeDivs();
   for (unsigned i = 0; i < getNumDivs(); ++i) {
     // Check if a division representation exists for the `i^th` local var.
     if (denoms[i] == 0)
@@ -472,6 +475,15 @@
   }
 }
 
+void DivisionRepr::normalizeDivs() {
+  for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
+    // Skip divisions without a representation (zero denominator).
+    if (getDenom(i) == 0 || getDividend(i).empty())
+      continue;
+    normalizeDiv(getDividend(i), getDenom(i));
+  }
+}
+
 void DivisionRepr::insertDiv(unsigned pos, ArrayRef<MPInt> dividend,
                              const MPInt &divisor) {
   assert(pos <= getNumDivs() && "Invalid insertion position");
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -10,6 +10,7 @@
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
+  UtilsTest.cpp
 )
 
 target_link_libraries(MLIRPresburgerTests
diff --git a/mlir/unittests/Analysis/Presburger/UtilsTest.cpp b/mlir/unittests/Analysis/Presburger/UtilsTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/UtilsTest.cpp
@@ -0,0 +1,80 @@
+//===- UtilsTest.cpp - Tests for Utils file -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/Utils.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+/// Construct `DivisionRepr(numVars, numDivs)` and set division `i` to
+/// `dividends[i]` / `divisors[i]` for every row `i` of `dividends`.
+static DivisionRepr parseDivisionRepr(unsigned numVars, unsigned numDivs,
+                                      ArrayRef<ArrayRef<MPInt>> dividends,
+                                      ArrayRef<MPInt> divisors) {
+  DivisionRepr repr(numVars, numDivs);
+  for (unsigned i = 0, rows = dividends.size(); i < rows; ++i)
+    repr.setDiv(i, dividends[i], divisors[i]);
+  return repr;
+}
+
+/// Check that `a` and `b` have the same number of variables and divisions, and
+/// that each division either lacks a representation in both or has the same
+/// divisor and dividend in both.
+static void checkEqual(DivisionRepr &a, DivisionRepr &b) {
+  EXPECT_EQ(a.getNumVars(), b.getNumVars());
+  EXPECT_EQ(a.getNumDivs(), b.getNumDivs());
+  for (unsigned i = 0, rows = a.getNumDivs(); i < rows; ++i) {
+    EXPECT_EQ(a.hasRepr(i), b.hasRepr(i));
+    if (!a.hasRepr(i))
+      continue;
+    EXPECT_TRUE(a.getDenom(i) == b.getDenom(i));
+    EXPECT_TRUE(a.getDividend(i).equals(b.getDividend(i)));
+  }
+}
+
+TEST(UtilsTest, ParseAndCompareDivisionReprTest) {
+  // Always allow duplicate divisions to be merged.
+  auto merge = [](unsigned, unsigned) -> bool { return true; };
+  DivisionRepr a = parseDivisionRepr(1, 1, {{MPInt(1), MPInt(2)}}, {MPInt(2)}),
+               b = parseDivisionRepr(1, 1, {{MPInt(1), MPInt(2)}}, {MPInt(2)}),
+               c = parseDivisionRepr(2, 2,
+                                     {{MPInt(0), MPInt(1), MPInt(2)},
+                                      {MPInt(0), MPInt(1), MPInt(2)}},
+                                     {MPInt(2), MPInt(2)});
+  // `c` has two identical divisions; once they are merged it should equal `a`.
+  c.removeDuplicateDivs(merge);
+  checkEqual(a, b);
+  checkEqual(a, c);
+}
+
+TEST(UtilsTest, DivisionReprNormalizeTest) {
+  // Always allow duplicate divisions to be merged.
+  auto merge = [](unsigned, unsigned) -> bool { return true; };
+  DivisionRepr a = parseDivisionRepr(2, 1, {{MPInt(1), MPInt(2), MPInt(-1)}},
+                                     {MPInt(2)}),
+               b = parseDivisionRepr(2, 1, {{MPInt(16), MPInt(32), MPInt(-16)}},
+                                     {MPInt(32)}),
+               c = parseDivisionRepr(1, 1, {{MPInt(12), MPInt(-4)}},
+                                     {MPInt(8)}),
+               d = parseDivisionRepr(2, 2,
+                                     {{MPInt(1), MPInt(2), MPInt(-1)},
+                                      {MPInt(4), MPInt(8), MPInt(-4)}},
+                                     {MPInt(2), MPInt(8)});
+  // `b` is `a` scaled by 16, so normalizing `b` should make it equal to `a`.
+  b.removeDuplicateDivs(merge);
+  // `d` has two divisions that differ only by a factor of 4, so they become
+  // duplicates after normalization and are merged; `d` should then equal the
+  // normalized `c`.
+  c.removeDuplicateDivs(merge);
+  d.removeDuplicateDivs(merge);
+  checkEqual(a, b);
+  checkEqual(c, d);
+}