diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -156,6 +156,11 @@
     denoms[i] = divisor;
   }
 
+  /// Normalize every division with a nonzero divisor by dividing its dividend
+  /// and divisor by their GCD. For example, (2x + 2) / 4 becomes (x + 1) / 2,
+  /// so divisions that are scalar multiples of each other become identical.
+  void normalizeDivs();
+
   void insertDiv(unsigned pos, ArrayRef<MPInt> dividend, const MPInt &divisor);
   void insertDiv(unsigned pos, unsigned num = 1);
 
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -431,5 +431,10 @@
     llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
 
+  // Normalize all divisions by their GCD first, so that divisions that are
+  // scalar multiples of each other, e.g. (2x + 2) / 4 and (x + 1) / 2, become
+  // identical and are merged by the duplicate search below.
+  normalizeDivs();
+
   // Find and merge duplicate divisions.
   // TODO: Add division normalization to support divisions that differ by
   // a constant.
@@ -472,6 +477,15 @@
   }
 }
 
+void DivisionRepr::normalizeDivs() {
+  for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
+    // A zero divisor or an empty dividend leaves nothing to normalize.
+    if (getDenom(i) == 0 || getDividend(i).empty())
+      continue;
+    normalizeDiv(getDividend(i), getDenom(i));
+  }
+}
+
 void DivisionRepr::insertDiv(unsigned pos, ArrayRef<MPInt> dividend,
                              const MPInt &divisor) {
   assert(pos <= getNumDivs() && "Invalid insertion position");
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -10,6 +10,7 @@
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
+  DivisionReprTest.cpp
 )
 
 target_link_libraries(MLIRPresburgerTests
diff --git a/mlir/unittests/Analysis/Presburger/DivisionReprTest.cpp b/mlir/unittests/Analysis/Presburger/DivisionReprTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/DivisionReprTest.cpp
@@ -0,0 +1,78 @@
+//===- DivisionReprTest.cpp - Tests for DivisionRepr class ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/Utils.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+/// Build a DivisionRepr over `numVars` variables with `numDivs` divisions,
+/// where division `i` is `dividends[i]` / `divisors[i]`.
+static DivisionRepr parseDivisionRepr(unsigned numVars, unsigned numDivs,
+                                      ArrayRef<ArrayRef<MPInt>> dividends,
+                                      ArrayRef<MPInt> divisors) {
+  DivisionRepr repr(numVars, numDivs);
+  for (unsigned i = 0, rows = dividends.size(); i < rows; ++i)
+    repr.setDiv(i, dividends[i], divisors[i]);
+  return repr;
+}
+
+/// Check that `a` and `b` have the same number of variables and divisions and
+/// the same representation for every division.
+static void checkEqual(DivisionRepr &a, DivisionRepr &b) {
+  EXPECT_EQ(a.getNumVars(), b.getNumVars());
+  EXPECT_EQ(a.getNumDivs(), b.getNumDivs());
+  for (unsigned i = 0, rows = a.getNumDivs(); i < rows; ++i) {
+    EXPECT_EQ(a.hasRepr(i), b.hasRepr(i));
+    if (!a.hasRepr(i))
+      continue;
+    EXPECT_TRUE(a.getDenom(i) == b.getDenom(i));
+    EXPECT_TRUE(a.getDividend(i).equals(b.getDividend(i)));
+  }
+}
+
+TEST(DivisionReprTest, ParseAndCompareTest) {
+  auto merge = [](unsigned, unsigned) -> bool { return true; };
+  DivisionRepr a = parseDivisionRepr(1, 1, {{MPInt(1), MPInt(2)}}, {MPInt(2)}),
+               b = parseDivisionRepr(1, 1, {{MPInt(1), MPInt(2)}}, {MPInt(2)}),
+               c = parseDivisionRepr(2, 2,
+                                     {{MPInt(0), MPInt(1), MPInt(2)},
+                                      {MPInt(0), MPInt(1), MPInt(2)}},
+                                     {MPInt(2), MPInt(2)});
+  // `c` holds the same division twice; merging the duplicate must leave a
+  // representation identical to `a`.
+  c.removeDuplicateDivs(merge);
+  checkEqual(a, b);
+  checkEqual(a, c);
+}
+
+TEST(DivisionReprTest, NormalizeTest) {
+  auto merge = [](unsigned, unsigned) -> bool { return true; };
+  DivisionRepr a = parseDivisionRepr(2, 1, {{MPInt(1), MPInt(2), MPInt(-1)}},
+                                     {MPInt(2)}),
+               b = parseDivisionRepr(2, 1, {{MPInt(16), MPInt(32), MPInt(-16)}},
+                                     {MPInt(32)}),
+               c = parseDivisionRepr(1, 1, {{MPInt(12), MPInt(-4)}},
+                                     {MPInt(8)}),
+               d = parseDivisionRepr(2, 2,
+                                     {{MPInt(1), MPInt(2), MPInt(-1)},
+                                      {MPInt(4), MPInt(8), MPInt(-4)}},
+                                     {MPInt(2), MPInt(8)});
+  // `b` is `a` scaled by 16, so normalization alone must turn it into `a`.
+  b.removeDuplicateDivs(merge);
+  // `c` normalizes by a factor of 4. The second division of `d` is 4 times its
+  // first, so after normalization the two are duplicates and get merged; the
+  // result must match the normalized `c`.
+  c.removeDuplicateDivs(merge);
+  d.removeDuplicateDivs(merge);
+  checkEqual(a, b);
+  checkEqual(c, d);
+}