diff --git a/mlir/include/mlir/Analysis/Presburger/MPInt.h b/mlir/include/mlir/Analysis/Presburger/MPInt.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/MPInt.h @@ -0,0 +1,580 @@ +//===- MPInt.h - MLIR MPInt Class -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a simple class to represent arbitrary precision signed integers. +// Unlike APInt, one does not have to specify a fixed maximum size, and the +// integer can take on any arbitrary values. This is optimized for small-values +// by providing fast-paths for the cases when the value stored fits in 64-bits. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_MPINT_H +#define MLIR_ANALYSIS_PRESBURGER_MPINT_H + +#include "mlir/Analysis/Presburger/SlowMPInt.h" +#include "mlir/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace presburger { + +namespace detail { +// If builtin intrinsics for overflow-checked arithmetic are available, +// use them. Otherwise, call through to LLVM's overflow-checked arithmetic +// functionality. Those functions also have such macro-gated uses of intrinsics, +// however they are not always_inlined, which is important for us to achieve +// high-performance; calling the functions directly would result in a slowdown +// of 1.15x. 
+__attribute__((always_inline)) inline bool addOverflow(int64_t x, int64_t y, + int64_t &result) { +#if __has_builtin(__builtin_add_overflow) + return __builtin_add_overflow(x, y, &result); +#else + return llvm::AddOverflow(x, y, result); +#endif +} +__attribute__((always_inline)) inline bool subOverflow(int64_t x, int64_t y, + int64_t &result) { +#if __has_builtin(__builtin_sub_overflow) + return __builtin_sub_overflow(x, y, &result); +#else + return llvm::subOverflow(x, y, result); +#endif +} +__attribute__((always_inline)) inline bool mulOverflow(int64_t x, int64_t y, + int64_t &result) { +#if __has_builtin(__builtin_mul_overflow) + return __builtin_mul_overflow(x, y, &result); +#else + return llvm::MulOverflow(x, y, result); +#endif +} +} // namespace detail + +/// This class provides support for multi-precision arithmetic. +/// +/// Unlike APInt, this extends the precision as necessary to prevent overflows +/// and supports operations between objects with differing internal precisions. +/// +/// This is optimized for small-values by providing fast-paths for the cases +/// when the value stored fits in 64-bits. We annotate all fastpaths by using +/// the LLVM_LIKELY/LLVM_UNLIKELY annotations. Removing these would result in +/// a 1.2x performance slowdown. +/// +/// We always_inline all operations; removing these results in a 1.5x +/// performance slowdown. +/// +/// When holdsSlow is true, a SlowMPInt is held in the union. If it is false, +/// the int64_t is held. Using std::variant instead would significantly impact +/// performance. 
+class MPInt { +private: + union { + int64_t val64; + detail::SlowMPInt valSlow; + }; + unsigned holdsSlow; + + __attribute__((always_inline)) void init64(int64_t o) { + if (LLVM_UNLIKELY(isLarge())) + valSlow.detail::SlowMPInt::~SlowMPInt(); + val64 = o; + holdsSlow = false; + } + __attribute__((always_inline)) void initAP(const detail::SlowMPInt &o) { + if (LLVM_LIKELY(isSmall())) { + // The data in memory could be in an arbitrary state, not necessarily + // corresponding to any valid state of valSlow; we cannot call any member + // functions, e.g. the assignment operator on it, as they may access the + // invalid internal state. We instead construct a new object using + // placement new. + new (&valSlow) detail::SlowMPInt(o); + } else { + // In this case, we need to use the assignment operator, because if we use + // placement-new as above we would lose track of allocated memory + // and leak it. + valSlow = o; + } + holdsSlow = true; + } + + __attribute__((always_inline)) explicit MPInt(const detail::SlowMPInt &val) + : valSlow(val), holdsSlow(true) {} + __attribute__((always_inline)) bool isSmall() const { return !holdsSlow; } + __attribute__((always_inline)) bool isLarge() const { return holdsSlow; } + __attribute__((always_inline)) int64_t get64() const { + assert(isSmall()); + return val64; + } + __attribute__((always_inline)) int64_t &get64() { + assert(isSmall()); + return val64; + } + __attribute__((always_inline)) const detail::SlowMPInt &getAP() const { + assert(isLarge()); + return valSlow; + } + __attribute__((always_inline)) detail::SlowMPInt &getAP() { + assert(isLarge()); + return valSlow; + } + explicit operator detail::SlowMPInt() const { + if (isSmall()) + return detail::SlowMPInt(get64()); + return getAP(); + } + __attribute__((always_inline)) detail::SlowMPInt getAsAP() const { + return detail::SlowMPInt(*this); + } + +public: + __attribute__((always_inline)) explicit MPInt(int64_t val) + : val64(val), holdsSlow(false) {} + 
__attribute__((always_inline)) MPInt() : MPInt(0) {}
+  __attribute__((always_inline)) ~MPInt() {
+    if (LLVM_UNLIKELY(isLarge()))
+      valSlow.detail::SlowMPInt::~SlowMPInt();
+  }
+  __attribute__((always_inline)) MPInt(const MPInt &o)
+      : val64(o.val64), holdsSlow(false) {
+    if (LLVM_UNLIKELY(o.isLarge()))
+      initAP(o.valSlow);
+  }
+  __attribute__((always_inline)) MPInt &operator=(const MPInt &o) {
+    if (LLVM_LIKELY(o.isSmall())) {
+      init64(o.val64);
+      return *this;
+    }
+    initAP(o.valSlow);
+    return *this;
+  }
+  __attribute__((always_inline)) MPInt &operator=(int x) {
+    init64(x);
+    return *this;
+  }
+  __attribute__((always_inline)) explicit operator int64_t() const {
+    if (isSmall())
+      return get64();
+    return static_cast<int64_t>(getAP());
+  }
+
+  bool operator==(const MPInt &o) const;
+  bool operator!=(const MPInt &o) const;
+  bool operator>(const MPInt &o) const;
+  bool operator<(const MPInt &o) const;
+  bool operator<=(const MPInt &o) const;
+  bool operator>=(const MPInt &o) const;
+  MPInt operator+(const MPInt &o) const;
+  MPInt operator-(const MPInt &o) const;
+  MPInt operator*(const MPInt &o) const;
+  MPInt operator/(const MPInt &o) const;
+  MPInt operator%(const MPInt &o) const;
+  MPInt &operator+=(const MPInt &o);
+  MPInt &operator-=(const MPInt &o);
+  MPInt &operator*=(const MPInt &o);
+  MPInt &operator/=(const MPInt &o);
+  MPInt &operator%=(const MPInt &o);
+  MPInt operator-() const;
+  MPInt &operator++();
+  MPInt &operator--();
+
+  // Divide by a number that is known to be positive.
+  // This is slightly more efficient because it saves an overflow check.
+  MPInt divByPositive(const MPInt &o) const;
+  MPInt &divByPositiveInPlace(const MPInt &o);
+
+  friend MPInt abs(const MPInt &x);
+  friend MPInt gcdRange(ArrayRef<MPInt> range);
+  friend MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs);
+  friend MPInt floorDiv(const MPInt &lhs, const MPInt &rhs);
+  friend MPInt gcd(const MPInt &a, const MPInt &b);
+  friend MPInt lcm(const MPInt &a, const MPInt &b);
+  friend MPInt mod(const MPInt &lhs, const MPInt &rhs);
+
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const;
+  void dump() const;
+
+  /// ---------------------------------------------------------------------------
+  /// Convenience operator overloads for int64_t.
+  /// ---------------------------------------------------------------------------
+  friend MPInt &operator+=(MPInt &a, int64_t b);
+  friend MPInt &operator-=(MPInt &a, int64_t b);
+  friend MPInt &operator*=(MPInt &a, int64_t b);
+  friend MPInt &operator/=(MPInt &a, int64_t b);
+  friend MPInt &operator%=(MPInt &a, int64_t b);
+
+  friend bool operator==(const MPInt &a, int64_t b);
+  friend bool operator!=(const MPInt &a, int64_t b);
+  friend bool operator>(const MPInt &a, int64_t b);
+  friend bool operator<(const MPInt &a, int64_t b);
+  friend bool operator<=(const MPInt &a, int64_t b);
+  friend bool operator>=(const MPInt &a, int64_t b);
+  friend MPInt operator+(const MPInt &a, int64_t b);
+  friend MPInt operator-(const MPInt &a, int64_t b);
+  friend MPInt operator*(const MPInt &a, int64_t b);
+  friend MPInt operator/(const MPInt &a, int64_t b);
+  friend MPInt operator%(const MPInt &a, int64_t b);
+
+  friend bool operator==(int64_t a, const MPInt &b);
+  friend bool operator!=(int64_t a, const MPInt &b);
+  friend bool operator>(int64_t a, const MPInt &b);
+  friend bool operator<(int64_t a, const MPInt &b);
+  friend bool operator<=(int64_t a, const MPInt &b);
+  friend bool operator>=(int64_t a, const MPInt &b);
+  friend MPInt operator+(int64_t a, const MPInt &b);
+  friend MPInt operator-(int64_t a, const MPInt &b);
+ friend MPInt operator*(int64_t a, const MPInt &b); + friend MPInt operator/(int64_t a, const MPInt &b); + friend MPInt operator%(int64_t a, const MPInt &b); + + friend llvm::hash_code hash_value(const MPInt &x); // NOLINT +}; + +/// This just calls through to the operator int64_t, but it's useful when a +/// function pointer is required. (Although this is marked inline, it is still +/// possible to obtain and use a function pointer to this.) +__attribute__((always_inline)) inline int64_t int64FromMPInt(const MPInt &x) { + return int64_t(x); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const MPInt &x); + +// The RHS is always expected to be positive, and the result +/// is always non-negative. +MPInt mod(const MPInt &lhs, const MPInt &rhs); + +/// We define the operations here in the header to facilitate inlining. + +/// --------------------------------------------------------------------------- +/// Comparison operators. +/// --------------------------------------------------------------------------- +__attribute__((always_inline)) inline bool +MPInt::operator==(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return get64() == o.get64(); + return getAsAP() == o.getAsAP(); +} +__attribute__((always_inline)) inline bool +MPInt::operator!=(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return get64() != o.get64(); + return getAsAP() != o.getAsAP(); +} +__attribute__((always_inline)) inline bool +MPInt::operator>(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return get64() > o.get64(); + return getAsAP() > o.getAsAP(); +} +__attribute__((always_inline)) inline bool +MPInt::operator<(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return get64() < o.get64(); + return getAsAP() < o.getAsAP(); +} +__attribute__((always_inline)) inline bool +MPInt::operator<=(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return get64() <= o.get64(); + 
return getAsAP() <= o.getAsAP(); +} +__attribute__((always_inline)) inline bool +MPInt::operator>=(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return get64() >= o.get64(); + return getAsAP() >= o.getAsAP(); +} + +/// --------------------------------------------------------------------------- +/// Arithmetic operators. +/// --------------------------------------------------------------------------- +__attribute__((always_inline)) inline MPInt +MPInt::operator+(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + MPInt result; + bool overflow = detail::addOverflow(get64(), o.get64(), result.get64()); + if (LLVM_LIKELY(!overflow)) + return result; + return MPInt(getAsAP() + o.getAsAP()); + } + return MPInt(getAsAP() + o.getAsAP()); +} +__attribute__((always_inline)) inline MPInt +MPInt::operator-(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + MPInt result; + bool overflow = detail::subOverflow(get64(), o.get64(), result.get64()); + if (LLVM_LIKELY(!overflow)) + return result; + return MPInt(getAsAP() - o.getAsAP()); + } + return MPInt(getAsAP() - o.getAsAP()); +} +__attribute__((always_inline)) inline MPInt +MPInt::operator*(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + MPInt result; + bool overflow = detail::mulOverflow(get64(), o.get64(), result.get64()); + if (LLVM_LIKELY(!overflow)) + return result; + return MPInt(getAsAP() * o.getAsAP()); + } + return MPInt(getAsAP() * o.getAsAP()); +} + +__attribute__((always_inline)) inline MPInt +MPInt::divByPositive(const MPInt &o) const { + assert(o > 0); + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return MPInt(get64() / o.get64()); + return MPInt(getAsAP() / o.getAsAP()); +} + +__attribute__((always_inline)) inline MPInt +MPInt::operator/(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + if (LLVM_UNLIKELY(o.get64() == -1)) + return -*this; + return MPInt(get64() / o.get64()); + } + return 
MPInt(getAsAP() / o.getAsAP()); +} + +inline MPInt abs(const MPInt &x) { return MPInt(x >= 0 ? x : -x); } +inline MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs) { + if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) { + if (rhs == -1) + return -lhs; + int64_t x = (rhs.get64() > 0) ? -1 : 1; + return MPInt(((lhs.get64() != 0) && (lhs.get64() > 0) == (rhs.get64() > 0)) + ? ((lhs.get64() + x) / rhs.get64()) + 1 + : -(-lhs.get64() / rhs.get64())); + } + return MPInt(ceilDiv(lhs.getAsAP(), rhs.getAsAP())); +} +inline MPInt floorDiv(const MPInt &lhs, const MPInt &rhs) { + if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) { + if (rhs == -1) + return -lhs; + int64_t x = (rhs.get64() < 0) ? 1 : -1; + return MPInt( + ((lhs.get64() != 0) && ((lhs.get64() < 0) != (rhs.get64() < 0))) + ? -((-lhs.get64() + x) / rhs.get64()) - 1 + : lhs.get64() / rhs.get64()); + } + return MPInt(floorDiv(lhs.getAsAP(), rhs.getAsAP())); +} +// The RHS is always expected to be positive, and the result +/// is always non-negative. +inline MPInt mod(const MPInt &lhs, const MPInt &rhs) { + if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) + return MPInt(lhs.get64() % rhs.get64() < 0 + ? lhs.get64() % rhs.get64() + rhs.get64() + : lhs.get64() % rhs.get64()); + return MPInt(mod(lhs.getAsAP(), rhs.getAsAP())); +} + +__attribute__((always_inline)) inline MPInt gcd(const MPInt &a, + const MPInt &b) { + // TODO: fix unsigned/signed overflow issues + if (LLVM_LIKELY(a.isSmall() && b.isSmall())) + return MPInt(int64_t(llvm::GreatestCommonDivisor64(a.get64(), b.get64()))); + return MPInt(gcd(a.getAsAP(), b.getAsAP())); +} + +/// Returns the least common multiple of 'a' and 'b'. +inline MPInt lcm(const MPInt &a, const MPInt &b) { + MPInt x = abs(a); + MPInt y = abs(b); + return (x * y) / gcd(x, y); +} + +/// This operation cannot overflow. 
+inline MPInt MPInt::operator%(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return MPInt(get64() % o.get64());
+  return MPInt(getAsAP() % o.getAsAP());
+}
+
+inline MPInt MPInt::operator-() const {
+  if (LLVM_LIKELY(isSmall())) {
+    if (LLVM_LIKELY(get64() != std::numeric_limits<int64_t>::min()))
+      return MPInt(-get64());
+    return MPInt(-getAsAP());
+  }
+  return MPInt(-getAsAP());
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+__attribute__((always_inline)) inline MPInt &MPInt::operator+=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    int64_t result = get64();
+    bool overflow = detail::addOverflow(get64(), o.get64(), result);
+    if (LLVM_LIKELY(!overflow)) {
+      get64() = result;
+      return *this;
+    }
+    return *this = MPInt(getAsAP() + o.getAsAP());
+  }
+  return *this = MPInt(getAsAP() + o.getAsAP());
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator-=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    int64_t result = get64();
+    bool overflow = detail::subOverflow(get64(), o.get64(), result);
+    if (LLVM_LIKELY(!overflow)) {
+      get64() = result;
+      return *this;
+    }
+    return *this = MPInt(getAsAP() - o.getAsAP());
+  }
+  return *this = MPInt(getAsAP() - o.getAsAP());
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator*=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    int64_t result = get64();
+    bool overflow = detail::mulOverflow(get64(), o.get64(), result);
+    if (LLVM_LIKELY(!overflow)) {
+      get64() = result;
+      return *this;
+    }
+    return *this = MPInt(getAsAP() * o.getAsAP());
+  }
+  return *this = MPInt(getAsAP() * o.getAsAP());
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator/=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    if (LLVM_UNLIKELY(o.get64() == -1))
+
return *this = -*this; + get64() /= o.get64(); + return *this; + } + return *this = MPInt(getAsAP() / o.getAsAP()); +} + +__attribute__((always_inline)) inline MPInt & +MPInt::divByPositiveInPlace(const MPInt &o) { + assert(o > 0); + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + get64() /= o.get64(); + return *this; + } + return *this = MPInt(getAsAP() / o.getAsAP()); +} + +__attribute__((always_inline)) inline MPInt &MPInt::operator%=(const MPInt &o) { + *this = *this % o; + return *this; +} +__attribute__((always_inline)) inline MPInt &MPInt::operator++() { + *this += 1; + return *this; +} +__attribute__((always_inline)) inline MPInt &MPInt::operator--() { + *this -= 1; + return *this; +} + +/// ---------------------------------------------------------------------------- +/// Convenience operator overloads for int64_t. +/// ---------------------------------------------------------------------------- +__attribute__((always_inline)) inline MPInt &operator+=(MPInt &a, int64_t b) { + return a = a + b; +} +__attribute__((always_inline)) inline MPInt &operator-=(MPInt &a, int64_t b) { + return a = a - b; +} +__attribute__((always_inline)) inline MPInt &operator*=(MPInt &a, int64_t b) { + return a = a * b; +} +__attribute__((always_inline)) inline MPInt &operator/=(MPInt &a, int64_t b) { + return a = a / b; +} +__attribute__((always_inline)) inline MPInt &operator%=(MPInt &a, int64_t b) { + return a = a % b; +} +inline MPInt operator+(const MPInt &a, int64_t b) { return a + MPInt(b); } +inline MPInt operator-(const MPInt &a, int64_t b) { return a - MPInt(b); } +inline MPInt operator*(const MPInt &a, int64_t b) { return a * MPInt(b); } +inline MPInt operator/(const MPInt &a, int64_t b) { return a / MPInt(b); } +inline MPInt operator%(const MPInt &a, int64_t b) { return a % MPInt(b); } +inline MPInt operator+(int64_t a, const MPInt &b) { return MPInt(a) + b; } +inline MPInt operator-(int64_t a, const MPInt &b) { return MPInt(a) - b; } +inline MPInt operator*(int64_t a, 
const MPInt &b) { return MPInt(a) * b; } +inline MPInt operator/(int64_t a, const MPInt &b) { return MPInt(a) / b; } +inline MPInt operator%(int64_t a, const MPInt &b) { return MPInt(a) % b; } + +/// We provide special implementations of the comparison operators rather than +/// calling through as above, as this would result in a 1.2x slowdown. +inline bool operator==(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.get64() == b; + return a.getAP() == b; +} +inline bool operator!=(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.get64() != b; + return a.getAP() != b; +} +inline bool operator>(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.get64() > b; + return a.getAP() > b; +} +inline bool operator<(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.get64() < b; + return a.getAP() < b; +} +inline bool operator<=(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.get64() <= b; + return a.getAP() <= b; +} +inline bool operator>=(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.get64() >= b; + return a.getAP() >= b; +} +inline bool operator==(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a == b.get64(); + return a == b.getAP(); +} +inline bool operator!=(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a != b.get64(); + return a != b.getAP(); +} +inline bool operator>(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a > b.get64(); + return a > b.getAP(); +} +inline bool operator<(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a < b.get64(); + return a < b.getAP(); +} +inline bool operator<=(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a <= b.get64(); + return a <= b.getAP(); +} +inline bool operator>=(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a >= b.get64(); + return a >= b.getAP(); +} + +} // 
namespace presburger +} // namespace mlir + +#endif // MLIR_ANALYSIS_PRESBURGER_MPINT_H diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt --- a/mlir/lib/Analysis/Presburger/CMakeLists.txt +++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt @@ -2,6 +2,7 @@ IntegerRelation.cpp LinearTransform.cpp Matrix.cpp + MPInt.cpp PresburgerRelation.cpp PresburgerSpace.cpp PWMAFunction.cpp diff --git a/mlir/lib/Analysis/Presburger/MPInt.cpp b/mlir/lib/Analysis/Presburger/MPInt.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/MPInt.cpp @@ -0,0 +1,36 @@ +//===- MPInt.cpp - MLIR MPInt Class ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/MPInt.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; +using namespace presburger; + +llvm::hash_code mlir::presburger::hash_value(const MPInt &x) { + if (x.isSmall()) + return llvm::hash_value(x.val64); + return detail::hash_value(x.valSlow); +} + +/// --------------------------------------------------------------------------- +/// Printing. 
+/// --------------------------------------------------------------------------- +llvm::raw_ostream &MPInt::print(llvm::raw_ostream &os) const { + if (isSmall()) + return os << val64; + return os << valSlow; +} + +void MPInt::dump() const { print(llvm::errs()); } + +llvm::raw_ostream &mlir::presburger::operator<<(llvm::raw_ostream &os, + const MPInt &x) { + x.print(os); + return os; +} diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -3,6 +3,7 @@ IntegerRelationTest.cpp LinearTransformTest.cpp MatrixTest.cpp + MPIntTest.cpp PresburgerSetTest.cpp PresburgerSpaceTest.cpp PWMAFunctionTest.cpp diff --git a/mlir/unittests/Analysis/Presburger/MPIntTest.cpp b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp @@ -0,0 +1,111 @@ +//===- MPIntTest.cpp - Tests for MPInt ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/MPInt.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+TEST(MPIntTest, ops) {
+  MPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + five, ten);
+  EXPECT_EQ(five * five, 2 * ten + five);
+  EXPECT_EQ(five * five, 3 * ten - five);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(five / two, two);
+  EXPECT_EQ(five % two, two / two);
+
+  EXPECT_EQ(-ten % seven, -10 % 7);
+  EXPECT_EQ(ten % -seven, 10 % -7);
+  EXPECT_EQ(-ten % -seven, -10 % -7);
+  EXPECT_EQ(ten % seven, 10 % 7);
+
+  EXPECT_EQ(-ten / seven, -10 / 7);
+  EXPECT_EQ(ten / -seven, 10 / -7);
+  EXPECT_EQ(-ten / -seven, -10 / -7);
+  EXPECT_EQ(ten / seven, 10 / 7);
+
+  MPInt x = ten;
+  x += five;
+  EXPECT_EQ(x, 15);
+  x *= two;
+  EXPECT_EQ(x, 30);
+  x /= seven;
+  EXPECT_EQ(x, 4);
+  x -= two * 10;
+  EXPECT_EQ(x, -16);
+  x *= 2 * two;
+  EXPECT_EQ(x, -64);
+  x /= two / -2;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, ten);
+  EXPECT_GE(ten, ten);
+  EXPECT_EQ(ten, ten);
+  EXPECT_FALSE(ten != ten);
+  EXPECT_FALSE(ten < ten);
+  EXPECT_FALSE(ten > ten);
+  EXPECT_LT(five, ten);
+  EXPECT_GT(ten, five);
+}
+
+TEST(MPIntTest, ops64Overloads) {
+  MPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + 5, ten);
+  EXPECT_EQ(five + 5, 5 + five);
+  EXPECT_EQ(five * 5, 2 * ten + 5);
+  EXPECT_EQ(five * 5, 3 * ten - 5);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(5 / two, 2);
+  EXPECT_EQ(five / 2, 2);
+  EXPECT_EQ(2 % two, 0);
+  EXPECT_EQ(2 - two, 0);
+  EXPECT_EQ(2 % two, two % 2);
+
+  MPInt x = ten;
+  x += 5;
+  EXPECT_EQ(x, 15);
+  x *= 2;
+  EXPECT_EQ(x, 30);
+  x /= 7;
+  EXPECT_EQ(x, 4);
+  x -= 20;
+  EXPECT_EQ(x, -16);
+  x *= 4;
+  EXPECT_EQ(x, -64);
+  x /= -1;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, 10);
+  EXPECT_GE(ten, 10);
+  EXPECT_EQ(ten, 10);
+  EXPECT_FALSE(ten != 10);
+  EXPECT_FALSE(ten < 10);
+  EXPECT_FALSE(ten > 10);
+
EXPECT_LT(five, 10); + EXPECT_GT(ten, 5); + + EXPECT_LE(10, ten); + EXPECT_GE(10, ten); + EXPECT_EQ(10, ten); + EXPECT_FALSE(10 != ten); + EXPECT_FALSE(10 < ten); + EXPECT_FALSE(10 > ten); + EXPECT_LT(5, ten); + EXPECT_GT(10, five); +} + +TEST(MPIntTest, overflows) { + MPInt x(1ll << 60); + EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60)); + MPInt y(1ll << 62); + EXPECT_EQ((y + y + y + y + y + y) / y, 6); + EXPECT_EQ(-(2 * (-y)), 2 * y); // -(-2^63) overflow. +}