diff --git a/mlir/include/mlir/Analysis/Presburger/MPInt.h b/mlir/include/mlir/Analysis/Presburger/MPInt.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/MPInt.h @@ -0,0 +1,271 @@ +//===- MPInt.h - MLIR MPInt Class -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a simple class to represent arbitrary precision signed integers. +// Unlike APInt, one does not have to specify a fixed maximum size, and the +// integer can take on any aribtrary values. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_MPINT_H +#define MLIR_ANALYSIS_PRESBURGER_MPINT_H + +#include "mlir/Support/MathExtras.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace presburger { + +/// This class provides support for multi-precision arithmetic. +/// +/// Unlike APInt, this extends the precision as necessary to prevent overflows +/// and supports operations between objects with differing internal precisions. +/// +/// Since it uses APInt internally, MPInt (MultiPrecision Integer) stores values +/// in a 64-bit machine integer for small values and uses slower +/// arbitrary-precision arithmetic only for larger values. 
+class MPInt {
+public:
+  /// Construct the value zero.
+  MPInt() : val(APSInt::get(0)) {}
+  /// Construct from a 64-bit signed integer.
+  explicit MPInt(int64_t val) : val(APSInt::get(val)) {}
+  /// Construct from an existing APSInt, which is stored as-is.
+  explicit MPInt(const APSInt &val) : val(val) {}
+
+  /// Replace the held value with the 64-bit integer `val`.
+  MPInt &operator=(int64_t val) {
+    *this = MPInt(val);
+    return *this;
+  }
+
+  /// Convert to int64_t via APInt::getSExtValue; the held value is expected
+  /// to fit in 64 bits.
+  explicit operator int64_t() const { return val.getSExtValue(); }
+
+  /// Comparisons. These compare numeric values, independently of the bit
+  /// widths of the internal representations.
+  bool operator==(const MPInt &o) const;
+  bool operator!=(const MPInt &o) const;
+  bool operator<(const MPInt &o) const;
+  bool operator<=(const MPInt &o) const;
+  bool operator>(const MPInt &o) const;
+  bool operator>=(const MPInt &o) const;
+
+  /// Arithmetic. The internal width is grown as needed so that results do
+  /// not overflow.
+  MPInt operator-() const;
+  MPInt operator+(const MPInt &o) const;
+  MPInt operator-(const MPInt &o) const;
+  MPInt operator*(const MPInt &o) const;
+  MPInt operator/(const MPInt &o) const;
+  MPInt operator%(const MPInt &o) const;
+
+  /// Compound assignment, pre-increment and pre-decrement.
+  MPInt &operator+=(const MPInt &o);
+  MPInt &operator-=(const MPInt &o);
+  MPInt &operator*=(const MPInt &o);
+  MPInt &operator/=(const MPInt &o);
+  MPInt &operator%=(const MPInt &o);
+  MPInt &operator++();
+  MPInt &operator--();
+
+  /// Free-function helpers, declared as friends of MPInt.
+  friend MPInt abs(const MPInt &x);
+  friend MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs);
+  friend MPInt floorDiv(const MPInt &lhs, const MPInt &rhs);
+  friend MPInt greatestCommonDivisor(const MPInt &a, const MPInt &b);
+  /// Overload to compute a hash_code for a MPInt value.
+  friend llvm::hash_code hash_value(const MPInt &x); // NOLINT
+
+  /// Print the value to `os`; dump() prints it to llvm::errs().
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// Bit width of the internal APSInt representation.
+  unsigned getBitWidth() const { return val.getBitWidth(); }
+
+  // The held integer value.
+  //
+  // TODO: consider using APInt directly to avoid unnecessary repeated internal
+  // signedness checks. This requires refactoring, exposing, or duplicating
+  // APSInt::compareValues.
+  APSInt val;
+};
+
+/// This just calls through to the operator int64_t, but it's useful when a
+/// function pointer is required.
+inline int64_t int64FromMPInt(const MPInt &x) { return int64_t(x); }
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const MPInt &x);
+
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+MPInt mod(const MPInt &lhs, const MPInt &rhs);
+
+/// Returns the least common multiple of 'a' and 'b'.
+MPInt lcm(const MPInt &a, const MPInt &b);
+
+/// ---------------------------------------------------------------------------
+/// Convenience operator overloads for int64_t.
+/// ---------------------------------------------------------------------------
+inline MPInt &operator+=(MPInt &a, int64_t b) { return a += MPInt(b); }
+inline MPInt &operator-=(MPInt &a, int64_t b) { return a -= MPInt(b); }
+inline MPInt &operator*=(MPInt &a, int64_t b) { return a *= MPInt(b); }
+inline MPInt &operator/=(MPInt &a, int64_t b) { return a /= MPInt(b); }
+inline MPInt &operator%=(MPInt &a, int64_t b) { return a %= MPInt(b); }
+
+inline bool operator==(const MPInt &a, int64_t b) { return a == MPInt(b); }
+inline bool operator!=(const MPInt &a, int64_t b) { return a != MPInt(b); }
+inline bool operator>(const MPInt &a, int64_t b) { return a > MPInt(b); }
+inline bool operator<(const MPInt &a, int64_t b) { return a < MPInt(b); }
+inline bool operator<=(const MPInt &a, int64_t b) { return a <= MPInt(b); }
+inline bool operator>=(const MPInt &a, int64_t b) { return a >= MPInt(b); }
+inline MPInt operator+(const MPInt &a, int64_t b) { return a + MPInt(b); }
+inline MPInt operator-(const MPInt &a, int64_t b) { return a - MPInt(b); }
+inline MPInt operator*(const MPInt &a, int64_t b) { return a * MPInt(b); }
+inline MPInt operator/(const MPInt &a, int64_t b) { return a / MPInt(b); }
+inline MPInt operator%(const MPInt &a, int64_t b) { return a % MPInt(b); }
+
+inline bool operator==(int64_t a, const MPInt &b) { return MPInt(a) == b; }
+inline bool operator!=(int64_t a, const MPInt &b) { return MPInt(a) != b; }
+inline bool operator>(int64_t a, const MPInt &b) { return MPInt(a) > b; }
+inline bool operator<(int64_t a, const MPInt &b) { return MPInt(a) < b; }
+inline bool operator<=(int64_t a, const MPInt &b) { return MPInt(a) <= b; }
+inline bool operator>=(int64_t a, const MPInt &b) { return MPInt(a) >= b; }
+inline MPInt operator+(int64_t a, const MPInt &b) { return MPInt(a) + b; }
+inline MPInt operator-(int64_t a, const MPInt &b) { return MPInt(a) - b; }
+inline MPInt operator*(int64_t a, const MPInt &b) { return MPInt(a) * b; }
+inline MPInt operator/(int64_t a, const MPInt &b) { return MPInt(a) / b; }
+inline MPInt operator%(int64_t a, const MPInt &b) { return MPInt(a) % b; }
+
+/// We define the operations here in the header to facilitate inlining.
+
+/// ---------------------------------------------------------------------------
+/// Comparison operators.
+/// ---------------------------------------------------------------------------
+inline bool MPInt::operator==(const MPInt &o) const {
+  return APSInt::compareValues(val, o.val) == 0;
+}
+inline bool MPInt::operator!=(const MPInt &o) const {
+  return APSInt::compareValues(val, o.val) != 0;
+}
+inline bool MPInt::operator>(const MPInt &o) const {
+  return APSInt::compareValues(val, o.val) > 0;
+}
+inline bool MPInt::operator<(const MPInt &o) const {
+  return APSInt::compareValues(val, o.val) < 0;
+}
+inline bool MPInt::operator<=(const MPInt &o) const {
+  return APSInt::compareValues(val, o.val) <= 0;
+}
+inline bool MPInt::operator>=(const MPInt &o) const {
+  return APSInt::compareValues(val, o.val) >= 0;
+}
+
+/// ---------------------------------------------------------------------------
+/// Arithmetic operators.
+/// ---------------------------------------------------------------------------
+namespace detail {
+/// Bring a and b to have the same width and then call a.op(b, overflow).
+/// If the overflow bit becomes set, resize a and b to double the width and
+/// call a.op(b, overflow), returning its result. The operation with double
+/// widths should not also overflow.
+template <typename Function>
+inline APSInt runOpWithExpandOnOverflow(const APInt &a, const APInt &b,
+                                        const Function &op) {
+  bool overflow;
+  unsigned width = std::max(a.getBitWidth(), b.getBitWidth());
+  APInt ret = op(a.sextOrSelf(width), b.sextOrSelf(width), overflow);
+  if (!overflow)
+    return APSInt(ret, /*isUnsigned=*/false);
+
+  width *= 2;
+  ret = op(a.sextOrSelf(width), b.sextOrSelf(width), overflow);
+  assert(!overflow && "double width should be sufficient to avoid overflow!");
+  return APSInt(ret, /*isUnsigned=*/false);
+}
+} // namespace detail
+
+inline MPInt MPInt::operator+(const MPInt &o) const {
+  return MPInt(detail::runOpWithExpandOnOverflow(val, o.val,
+                                                 std::mem_fn(&APInt::sadd_ov)));
+}
+inline MPInt MPInt::operator-(const MPInt &o) const {
+  return MPInt(detail::runOpWithExpandOnOverflow(val, o.val,
+                                                 std::mem_fn(&APInt::ssub_ov)));
+}
+inline MPInt MPInt::operator*(const MPInt &o) const {
+  return MPInt(detail::runOpWithExpandOnOverflow(val, o.val,
+                                                 std::mem_fn(&APInt::smul_ov)));
+}
+/// Truncating signed division. The overflowing case (minimum value / -1) is
+/// handled by the width expansion in runOpWithExpandOnOverflow.
+inline MPInt MPInt::operator/(const MPInt &o) const {
+  return MPInt(detail::runOpWithExpandOnOverflow(val, o.val,
+                                                 std::mem_fn(&APInt::sdiv_ov)));
+}
+inline MPInt abs(const MPInt &x) { return x >= 0 ? x : -x; }
+
+/// Returns lhs / rhs rounded towards positive infinity.
+inline MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs) {
+  // Dividing the minimum value of a width by -1 overflows; negate instead.
+  if (rhs == -1)
+    return -lhs;
+  // APInt division asserts that both operands have the same bit width, but
+  // lhs and rhs may be stored at different widths, so sign-extend both to the
+  // wider one. With rhs != -1 the quotient fits in that width.
+  unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+  return MPInt(APSInt(llvm::APIntOps::RoundingSDiv(lhs.val.sextOrSelf(width),
+                                                   rhs.val.sextOrSelf(width),
+                                                   APInt::Rounding::UP),
+                      /*isUnsigned=*/false));
+}
+
+/// Returns lhs / rhs rounded towards negative infinity.
+inline MPInt floorDiv(const MPInt &lhs, const MPInt &rhs) {
+  // Dividing the minimum value of a width by -1 overflows; negate instead.
+  if (rhs == -1)
+    return -lhs;
+  // Bring both operands to a common width; see ceilDiv.
+  unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+  return MPInt(APSInt(llvm::APIntOps::RoundingSDiv(lhs.val.sextOrSelf(width),
+                                                   rhs.val.sextOrSelf(width),
+                                                   APInt::Rounding::DOWN),
+                      /*isUnsigned=*/false));
+}
+
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+inline MPInt mod(const MPInt &lhs, const MPInt &rhs) {
+  assert(rhs >= 1);
+  // Compute the truncated remainder once, then shift it into [0, rhs).
+  MPInt rem = lhs % rhs;
+  return rem < 0 ? rem + rhs : rem;
+}
+
+/// Returns the non-negative greatest common divisor of 'a' and 'b'.
+inline MPInt greatestCommonDivisor(const MPInt &a, const MPInt &b) {
+  // Use MPInt's abs(), which widens when negating the minimum value of a
+  // width; APInt::abs() would wrap that value back to a negative number.
+  MPInt x = abs(a);
+  MPInt y = abs(b);
+  // GreatestCommonDivisor requires operands of the same bit width. Both values
+  // are non-negative, so sign-extending them keeps them non-negative.
+  unsigned width = std::max(x.getBitWidth(), y.getBitWidth());
+  return MPInt(APSInt(llvm::APIntOps::GreatestCommonDivisor(
+                          x.val.sextOrSelf(width), y.val.sextOrSelf(width)),
+                      /*isUnsigned=*/false));
+}
+
+/// Returns the least common multiple of 'a' and 'b'. The result is
+/// non-negative; lcm(0, x) and lcm(0, 0) are 0.
+inline MPInt lcm(const MPInt &a, const MPInt &b) {
+  MPInt x = abs(a);
+  MPInt y = abs(b);
+  // If either value is zero the lcm is zero; this also avoids dividing by
+  // gcd(0, 0) == 0.
+  if (x == 0 || y == 0)
+    return MPInt(0);
+  // The gcd divides x exactly, so dividing first yields the same value as
+  // (x * y) / gcd while keeping the intermediate result smaller.
+  return (x / greatestCommonDivisor(x, y)) * y;
+}
+
+/// This operation cannot overflow.
+inline MPInt MPInt::operator%(const MPInt &o) const {
+  unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth());
+  return MPInt(APSInt(val.sextOrSelf(width).srem(o.val.sextOrSelf(width)),
+                      /*isUnsigned=*/false));
+}
+
+inline MPInt MPInt::operator-() const {
+  if (val.isMinSignedValue()) {
+    // Overflow only occurs when the value is the minimum possible value.
+    APSInt ret = val.extend(2 * val.getBitWidth());
+    return MPInt(-ret);
+  }
+  return MPInt(-val);
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+inline MPInt &MPInt::operator+=(const MPInt &o) {
+  *this = *this + o;
+  return *this;
+}
+inline MPInt &MPInt::operator-=(const MPInt &o) {
+  *this = *this - o;
+  return *this;
+}
+inline MPInt &MPInt::operator*=(const MPInt &o) {
+  *this = *this * o;
+  return *this;
+}
+inline MPInt &MPInt::operator/=(const MPInt &o) {
+  *this = *this / o;
+  return *this;
+}
+inline MPInt &MPInt::operator%=(const MPInt &o) {
+  *this = *this % o;
+  return *this;
+}
+inline MPInt &MPInt::operator++() {
+  *this += 1;
+  return *this;
+}
+
+inline MPInt &MPInt::operator--() {
+  *this -= 1;
+  return *this;
+}
+
+} // namespace presburger
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_MPINT_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -6,6 +6,7 @@
   PresburgerSpace.cpp
   PWMAFunction.cpp
   Simplex.cpp
+  MPInt.cpp
   Utils.cpp
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Analysis/Presburger/MPInt.cpp b/mlir/lib/Analysis/Presburger/MPInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/MPInt.cpp
@@ -0,0 +1,41 @@
+//===- MPInt.cpp - MLIR MPInt Class ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/MPInt.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace presburger;
+
+/// Hashes the numeric value of `x`. MPInt::operator== compares values
+/// independently of the bit width used to store them, so the hash must not
+/// depend on that width either; otherwise equal MPInts stored at different
+/// widths (e.g. 64 vs. 128 bits) would hash differently. Hash the value at
+/// its minimal signed bit width instead.
+llvm::hash_code mlir::presburger::hash_value(const MPInt &x) {
+  APInt value = x.val;
+  // An unsigned APSInt may need one extra bit to be represented as signed.
+  if (x.val.isUnsigned())
+    value = value.zext(value.getBitWidth() + 1);
+  return hash_value(value.sextOrTrunc(value.getMinSignedBits()));
+}
+
+/// ---------------------------------------------------------------------------
+/// Printing.
+/// ---------------------------------------------------------------------------
+/// Writes the held value to `os` and returns `os` for chaining.
+llvm::raw_ostream &MPInt::print(llvm::raw_ostream &os) const {
+  return os << val;
+}
+/// Prints the held value to llvm::errs().
+void MPInt::dump() const { print(llvm::errs()); }
+/// Stream insertion operator; forwards to MPInt::print.
+llvm::raw_ostream &mlir::presburger::operator<<(llvm::raw_ostream &os,
+                                                const MPInt &x) {
+  return x.print(os);
+}
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -2,6 +2,7 @@
   IntegerPolyhedronTest.cpp
   LinearTransformTest.cpp
   MatrixTest.cpp
+  MPIntTest.cpp
   PresburgerSetTest.cpp
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
diff --git a/mlir/unittests/Analysis/Presburger/MPIntTest.cpp b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp
@@ -0,0 +1,108 @@
+//===- MPIntTest.cpp - Tests for MPInt ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/MPInt.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+TEST(MPIntTest, ops) {
+  MPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + five, ten);
+  EXPECT_EQ(five * five, 2 * ten + five);
+  EXPECT_EQ(five * five, 3 * ten - five);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(five / two, two);
+  EXPECT_EQ(five % two, two / two);
+
+  EXPECT_EQ(-ten % seven, -10 % 7);
+  EXPECT_EQ(ten % -seven, 10 % -7);
+  EXPECT_EQ(-ten % -seven, -10 % -7);
+  EXPECT_EQ(ten % seven, 10 % 7);
+
+  EXPECT_EQ(-ten / seven, -10 / 7);
+  EXPECT_EQ(ten / -seven, 10 / -7);
+  EXPECT_EQ(-ten / -seven, -10 / -7);
+  EXPECT_EQ(ten / seven, 10 / 7);
+
+  MPInt x = ten;
+  x += five;
+  EXPECT_EQ(x, 15);
+  x *= two;
+  EXPECT_EQ(x, 30);
+  x /= seven;
+  EXPECT_EQ(x, 4);
+  x -= two * 10;
+  EXPECT_EQ(x, -16);
+  x *= 2 * two;
+  EXPECT_EQ(x, -64);
+  x /= two / -2;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, ten);
+  EXPECT_GE(ten, ten);
+  EXPECT_EQ(ten, ten);
+  EXPECT_FALSE(ten != ten);
+  EXPECT_FALSE(ten < ten);
+  EXPECT_FALSE(ten > ten);
+  EXPECT_LT(five, ten);
+  EXPECT_GT(ten, five);
+}
+
+TEST(MPIntTest, ops64Overloads) {
+  MPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + 5, ten);
+  EXPECT_EQ(five + 5, 5 + five);
+  EXPECT_EQ(five * 5, 2 * ten + 5);
+  EXPECT_EQ(five * 5, 3 * ten - 5);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(5 / two, 2);
+  EXPECT_EQ(five / 2, 2);
+  EXPECT_EQ(2 % two, 0);
+  EXPECT_EQ(2 - two, 0);
+  EXPECT_EQ(2 % two, two % 2);
+
+  MPInt x = ten;
+  x += 5;
+  EXPECT_EQ(x, 15);
+  x *= 2;
+  EXPECT_EQ(x, 30);
+  x /= 7;
+  EXPECT_EQ(x, 4);
+  x -= 20;
+  EXPECT_EQ(x, -16);
+  x *= 4;
+  EXPECT_EQ(x, -64);
+  x /= -1;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, 10);
+  EXPECT_GE(ten, 10);
+  EXPECT_EQ(ten, 10);
+  EXPECT_FALSE(ten != 10);
+  EXPECT_FALSE(ten < 10);
+  EXPECT_FALSE(ten > 10);
+  EXPECT_LT(five, 10);
+  EXPECT_GT(ten, 5);
+
+  EXPECT_LE(10, ten);
+  EXPECT_GE(10, ten);
+  EXPECT_EQ(10, ten);
+  EXPECT_FALSE(10 != ten);
+  EXPECT_FALSE(10 < ten);
+  EXPECT_FALSE(10 > ten);
+  EXPECT_LT(5, ten);
+  EXPECT_GT(10, five);
+}
+
+TEST(MPIntTest, overflows) {
+  MPInt x(1ll << 60); // x * x and higher powers do not fit in 64 bits.
+  EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60));
+}