diff --git a/mlir/include/mlir/Analysis/Presburger/MPInt.h b/mlir/include/mlir/Analysis/Presburger/MPInt.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/MPInt.h
@@ -0,0 +1,588 @@
+//===- MPInt.h - MLIR MPInt Class -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a simple class to represent arbitrary precision signed integers.
+// Unlike APInt, one does not have to specify a fixed maximum size, and the
+// integer can take on any arbitrary values. This is optimized for small values
+// by providing fast paths for the cases when the value stored fits in 64 bits.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_MPINT_H
+#define MLIR_ANALYSIS_PRESBURGER_MPINT_H
+
+#include "mlir/Analysis/Presburger/SlowMPInt.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <limits>
+
+namespace mlir {
+namespace presburger {
+
+namespace detail {
+// If builtin intrinsics for overflow-checked arithmetic are available,
+// use them. Otherwise, call through to LLVM's overflow-checked arithmetic
+// functionality. Those functions also have such macro-gated uses of intrinsics,
+// however they are not always_inlined, which is important for us to achieve
+// high-performance; calling the functions directly would result in a slowdown
+// of 1.15x.
+
+/// Computes x + y. Returns true iff the sum overflows int64_t; when it does
+/// not, `result` holds the sum.
+__attribute__((always_inline)) inline bool addOverflow(int64_t x, int64_t y,
+                                                       int64_t &result) {
+#if __has_builtin(__builtin_add_overflow)
+  return __builtin_add_overflow(x, y, &result);
+#else
+  return llvm::AddOverflow(x, y, result);
+#endif
+}
+/// Computes x - y. Returns true iff the difference overflows int64_t; when it
+/// does not, `result` holds the difference.
+__attribute__((always_inline)) inline bool subOverflow(int64_t x, int64_t y,
+                                                       int64_t &result) {
+#if __has_builtin(__builtin_sub_overflow)
+  return __builtin_sub_overflow(x, y, &result);
+#else
+  // LLVM's helper is spelled SubOverflow (cf. AddOverflow / MulOverflow).
+  return llvm::SubOverflow(x, y, result);
+#endif
+}
+/// Computes x * y. Returns true iff the product overflows int64_t; when it
+/// does not, `result` holds the product.
+__attribute__((always_inline)) inline bool mulOverflow(int64_t x, int64_t y,
+                                                       int64_t &result) {
+#if __has_builtin(__builtin_mul_overflow)
+  return __builtin_mul_overflow(x, y, &result);
+#else
+  return llvm::MulOverflow(x, y, result);
+#endif
+}
+} // namespace detail
+
+/// This class provides support for multi-precision arithmetic.
+///
+/// Unlike APInt, this extends the precision as necessary to prevent overflows
+/// and supports operations between objects with differing internal precisions.
+///
+/// This is optimized for small values by providing fast paths for the cases
+/// when the value stored fits in 64 bits. We annotate all fast paths by using
+/// the LLVM_LIKELY/LLVM_UNLIKELY annotations. Removing these would result in
+/// a 1.2x performance slowdown.
+///
+/// We always_inline all operations; removing these results in a 1.5x
+/// performance slowdown.
+///
+/// When holdsLarge is true, a SlowMPInt is held in the union. If it is false,
+/// the int64_t is held. Using std::variant instead would lead to significantly
+/// worse performance.
+class MPInt {
+private:
+  /// Exactly one member is active at a time, as recorded by holdsLarge.
+  union {
+    int64_t valSmall;
+    detail::SlowMPInt valLarge;
+  };
+  /// Non-zero iff valLarge is the active union member.
+  unsigned holdsLarge;
+
+  /// Make *this hold the small value `o`, destroying any large value first.
+  __attribute__((always_inline)) void initSmall(int64_t o) {
+    if (LLVM_UNLIKELY(isLarge()))
+      valLarge.detail::SlowMPInt::~SlowMPInt();
+    valSmall = o;
+    holdsLarge = false;
+  }
+  /// Make *this hold the large value `o`.
+  __attribute__((always_inline)) void initLarge(const detail::SlowMPInt &o) {
+    if (LLVM_LIKELY(isSmall())) {
+      // The data in memory could be in an arbitrary state, not necessarily
+      // corresponding to any valid state of valLarge; we cannot call any member
+      // functions, e.g. the assignment operator on it, as they may access the
+      // invalid internal state. We instead construct a new object using
+      // placement new.
+      new (&valLarge) detail::SlowMPInt(o);
+    } else {
+      // In this case, we need to use the assignment operator, because if we use
+      // placement-new as above we would lose track of allocated memory
+      // and leak it.
+      valLarge = o;
+    }
+    holdsLarge = true;
+  }
+
+  /// Always produces the large representation, even if `val` would fit in
+  /// 64 bits.
+  __attribute__((always_inline)) explicit MPInt(const detail::SlowMPInt &val)
+      : valLarge(val), holdsLarge(true) {}
+  __attribute__((always_inline)) bool isSmall() const { return !holdsLarge; }
+  __attribute__((always_inline)) bool isLarge() const { return holdsLarge; }
+  /// Get the stored value. For getSmall/Large,
+  /// the stored value should be small/large.
+  __attribute__((always_inline)) int64_t getSmall() const {
+    assert(isSmall() && "Wrong accessor called!");
+    return valSmall;
+  }
+  __attribute__((always_inline)) int64_t &getSmall() {
+    assert(isSmall() && "Wrong accessor called!");
+    return valSmall;
+  }
+  __attribute__((always_inline)) const detail::SlowMPInt &getLarge() const {
+    assert(isLarge() && "Wrong accessor called!");
+    return valLarge;
+  }
+  __attribute__((always_inline)) detail::SlowMPInt &getLarge() {
+    assert(isLarge() && "Wrong accessor called!");
+    return valLarge;
+  }
+  /// Converts either representation to an arbitrary-precision SlowMPInt.
+  explicit operator detail::SlowMPInt() const {
+    if (isSmall())
+      return detail::SlowMPInt(getSmall());
+    return getLarge();
+  }
+
+public:
+  __attribute__((always_inline)) explicit MPInt(int64_t val)
+      : valSmall(val), holdsLarge(false) {}
+  __attribute__((always_inline)) MPInt() : MPInt(0) {}
+  __attribute__((always_inline)) ~MPInt() {
+    if (LLVM_UNLIKELY(isLarge()))
+      valLarge.detail::SlowMPInt::~SlowMPInt();
+  }
+  /// Copy constructor. Only the active member of `o` is read: reading
+  /// o.valSmall while `o` holds a SlowMPInt would access an inactive union
+  /// member.
+  __attribute__((always_inline)) MPInt(const MPInt &o)
+      : valSmall(0), holdsLarge(false) {
+    if (LLVM_UNLIKELY(o.isLarge()))
+      initLarge(o.valLarge);
+    else
+      valSmall = o.valSmall;
+  }
+  __attribute__((always_inline)) MPInt &operator=(const MPInt &o) {
+    if (LLVM_LIKELY(o.isSmall())) {
+      initSmall(o.valSmall);
+      return *this;
+    }
+    initLarge(o.valLarge);
+    return *this;
+  }
+  /// Assign a machine integer. The parameter is int64_t so that assigning an
+  /// int64_t value does not silently narrow it through an int parameter.
+  __attribute__((always_inline)) MPInt &operator=(int64_t x) {
+    initSmall(x);
+    return *this;
+  }
+  /// Converts to int64_t. For a large value this delegates to SlowMPInt's
+  /// int64_t conversion (presumably the value must fit — confirm in
+  /// SlowMPInt).
+  __attribute__((always_inline)) explicit operator int64_t() const {
+    if (isSmall())
+      return getSmall();
+    return static_cast<int64_t>(getLarge());
+  }
+
+  bool operator==(const MPInt &o) const;
+  bool operator!=(const MPInt &o) const;
+  bool operator>(const MPInt &o) const;
+  bool operator<(const MPInt &o) const;
+  bool operator<=(const MPInt &o) const;
+  bool operator>=(const MPInt &o) const;
+  MPInt operator+(const MPInt &o) const;
+  MPInt operator-(const MPInt &o) const;
+  MPInt operator*(const MPInt &o) const;
+  MPInt operator/(const MPInt &o) const;
+  MPInt operator%(const MPInt &o) const;
+  MPInt &operator+=(const MPInt &o);
+  MPInt &operator-=(const MPInt &o);
+  MPInt &operator*=(const MPInt &o);
+  MPInt &operator/=(const MPInt &o);
+  MPInt &operator%=(const MPInt &o);
+  MPInt operator-() const;
+  MPInt &operator++();
+  MPInt &operator--();
+
+  // Divide by a number that is known to be positive.
+  // This is slightly more efficient because it saves an overflow check.
+  MPInt divByPositive(const MPInt &o) const;
+  MPInt &divByPositiveInPlace(const MPInt &o);
+
+  friend MPInt abs(const MPInt &x);
+  friend MPInt gcdRange(ArrayRef<MPInt> range);
+  friend MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs);
+  friend MPInt floorDiv(const MPInt &lhs, const MPInt &rhs);
+  friend MPInt gcd(const MPInt &a, const MPInt &b);
+  friend MPInt lcm(const MPInt &a, const MPInt &b);
+  friend MPInt mod(const MPInt &lhs, const MPInt &rhs);
+
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const;
+  void dump() const;
+
+  /// ---------------------------------------------------------------------------
+  /// Convenience operator overloads for int64_t.
+  /// ---------------------------------------------------------------------------
+  friend MPInt &operator+=(MPInt &a, int64_t b);
+  friend MPInt &operator-=(MPInt &a, int64_t b);
+  friend MPInt &operator*=(MPInt &a, int64_t b);
+  friend MPInt &operator/=(MPInt &a, int64_t b);
+  friend MPInt &operator%=(MPInt &a, int64_t b);
+
+  friend bool operator==(const MPInt &a, int64_t b);
+  friend bool operator!=(const MPInt &a, int64_t b);
+  friend bool operator>(const MPInt &a, int64_t b);
+  friend bool operator<(const MPInt &a, int64_t b);
+  friend bool operator<=(const MPInt &a, int64_t b);
+  friend bool operator>=(const MPInt &a, int64_t b);
+  friend MPInt operator+(const MPInt &a, int64_t b);
+  friend MPInt operator-(const MPInt &a, int64_t b);
+  friend MPInt operator*(const MPInt &a, int64_t b);
+  friend MPInt operator/(const MPInt &a, int64_t b);
+  friend MPInt operator%(const MPInt &a, int64_t b);
+
+  friend bool operator==(int64_t a, const MPInt &b);
+  friend bool operator!=(int64_t a, const MPInt &b);
+  friend bool operator>(int64_t a, const MPInt &b);
+  friend bool operator<(int64_t a, const MPInt &b);
+  friend bool operator<=(int64_t a, const MPInt &b);
+  friend bool operator>=(int64_t a, const MPInt &b);
+  friend MPInt operator+(int64_t a, const MPInt &b);
+  friend MPInt operator-(int64_t a, const MPInt &b);
+  friend MPInt operator*(int64_t a, const MPInt &b);
+  friend MPInt operator/(int64_t a, const MPInt &b);
+  friend MPInt operator%(int64_t a, const MPInt &b);
+
+  friend llvm::hash_code hash_value(const MPInt &x); // NOLINT
+};
+
+/// This just calls through to the operator int64_t, but it's useful when a
+/// function pointer is required. (Although this is marked inline, it is still
+/// possible to obtain and use a function pointer to this.)
+__attribute__((always_inline)) inline int64_t int64FromMPInt(const MPInt &x) {
+  return int64_t(x);
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const MPInt &x);
+
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+MPInt mod(const MPInt &lhs, const MPInt &rhs);
+
+namespace detail {
+/// Returns true iff the C++ expression `x / y` overflows int64_t. Division
+/// overflows only when trying to negate the minimal signed value, i.e. for
+/// INT64_MIN / -1, whose exact quotient 2^63 is not representable.
+__attribute__((always_inline)) inline bool divWouldOverflow(int64_t x,
+                                                            int64_t y) {
+  return x == std::numeric_limits<int64_t>::min() && y == -1;
+}
+} // namespace detail
+
+/// We define the operations here in the header to facilitate inlining.
+
+/// ---------------------------------------------------------------------------
+/// Comparison operators.
+/// ---------------------------------------------------------------------------
+/// Each comparison uses the 64-bit fast path when both operands are small and
+/// otherwise compares the arbitrary-precision forms, so the result depends
+/// only on the represented values, not on which representation is in use.
+__attribute__((always_inline)) inline bool
+MPInt::operator==(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return getSmall() == o.getSmall();
+  return detail::SlowMPInt(*this) == detail::SlowMPInt(o);
+}
+__attribute__((always_inline)) inline bool
+MPInt::operator!=(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return getSmall() != o.getSmall();
+  return detail::SlowMPInt(*this) != detail::SlowMPInt(o);
+}
+__attribute__((always_inline)) inline bool
+MPInt::operator>(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return getSmall() > o.getSmall();
+  return detail::SlowMPInt(*this) > detail::SlowMPInt(o);
+}
+__attribute__((always_inline)) inline bool
+MPInt::operator<(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return getSmall() < o.getSmall();
+  return detail::SlowMPInt(*this) < detail::SlowMPInt(o);
+}
+__attribute__((always_inline)) inline bool
+MPInt::operator<=(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return getSmall() <= o.getSmall();
+  return detail::SlowMPInt(*this) <= detail::SlowMPInt(o);
+}
+__attribute__((always_inline)) inline bool
+MPInt::operator>=(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return getSmall() >= o.getSmall();
+  return detail::SlowMPInt(*this) >= detail::SlowMPInt(o);
+}
+
+/// ---------------------------------------------------------------------------
+/// Arithmetic operators.
+/// ---------------------------------------------------------------------------
+/// Each operator computes the 64-bit result directly into a small MPInt when
+/// both operands are small and the checked operation does not overflow;
+/// otherwise it recomputes the result in arbitrary precision via SlowMPInt,
+/// and the returned MPInt is then in the large representation.
+__attribute__((always_inline)) inline MPInt
+MPInt::operator+(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    // A default-constructed MPInt is small (0), so its int64_t slot can be
+    // used directly as the overflow-checked output.
+    MPInt result;
+    bool overflow =
+        detail::addOverflow(getSmall(), o.getSmall(), result.getSmall());
+    if (LLVM_LIKELY(!overflow))
+      return result;
+    // NOTE(review): this return duplicates the one below; presumably kept on
+    // purpose for codegen of the tuned fast path — benchmark before folding.
+    return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o));
+  }
+  return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o));
+}
+__attribute__((always_inline)) inline MPInt
+MPInt::operator-(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    // Same scheme as operator+: write the checked difference into a small
+    // MPInt and fall back to SlowMPInt on overflow.
+    MPInt result;
+    bool overflow =
+        detail::subOverflow(getSmall(), o.getSmall(), result.getSmall());
+    if (LLVM_LIKELY(!overflow))
+      return result;
+    return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o));
+  }
+  return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o));
+}
+__attribute__((always_inline)) inline MPInt
+MPInt::operator*(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    // Same scheme as operator+: write the checked product into a small MPInt
+    // and fall back to SlowMPInt on overflow.
+    MPInt result;
+    bool overflow =
+        detail::mulOverflow(getSmall(), o.getSmall(), result.getSmall());
+    if (LLVM_LIKELY(!overflow))
+      return result;
+    return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o));
+  }
+  return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o));
+}
+
+// Dividing by a number known to be positive never overflows: the only
+// overflowing int64_t division is INT64_MIN / -1, so divByPositive skips the
+// overflow check.
+__attribute__((always_inline)) inline MPInt
+MPInt::divByPositive(const MPInt &o) const {
+  assert(o > 0);
+  if (LLVM_LIKELY(isSmall() && o.isSmall()))
+    return MPInt(getSmall() / o.getSmall());
+  return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o));
+}
+
+/// Truncating division (rounds toward zero), like C++ `/` on integers.
+__attribute__((always_inline)) inline MPInt
+MPInt::operator/(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    // Division overflows only occur when negating the minimal possible value.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), o.getSmall())))
+      return -*this;
+    return MPInt(getSmall() / o.getSmall());
+  }
+  return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o));
+}
+
+/// Returns |x|. Unary minus promotes to the large form when x is INT64_MIN.
+inline MPInt abs(const MPInt &x) { return MPInt(x >= 0 ? x : -x); }
+
+namespace detail {
+/// Returns floor(x / y) for machine integers. Requires y != 0 and
+/// !divWouldOverflow(x, y). The truncated quotient is adjusted by one when the
+/// division is inexact, so x is never negated. NOTE(review): mlir::floorDiv
+/// and mlir::ceilDiv (mlir/Support/MathExtras.h) appear to compute -lhs
+/// internally, which overflows for INT64_MIN — hence these helpers.
+__attribute__((always_inline)) inline int64_t floorDivSmall(int64_t x,
+                                                            int64_t y) {
+  int64_t quotient = x / y; // Truncates toward zero.
+  int64_t remainder = x % y; // Has the sign of x.
+  // Truncation rounded up iff the exact quotient is negative and inexact.
+  if (remainder != 0 && ((remainder < 0) != (y < 0)))
+    --quotient;
+  return quotient;
+}
+/// Returns ceil(x / y) for machine integers, with the same preconditions as
+/// floorDivSmall.
+__attribute__((always_inline)) inline int64_t ceilDivSmall(int64_t x,
+                                                           int64_t y) {
+  int64_t quotient = x / y; // Truncates toward zero.
+  int64_t remainder = x % y; // Has the sign of x.
+  // Truncation rounded down iff the exact quotient is positive and inexact.
+  if (remainder != 0 && ((remainder < 0) == (y < 0)))
+    ++quotient;
+  return quotient;
+}
+} // namespace detail
+
+/// Returns ceil(lhs / rhs). `rhs` must be non-zero.
+inline MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs) {
+  if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) {
+    // Division overflows only occur when negating the minimal possible value.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(lhs.getSmall(), rhs.getSmall())))
+      return -lhs;
+    return MPInt(detail::ceilDivSmall(lhs.getSmall(), rhs.getSmall()));
+  }
+  return MPInt(ceilDiv(detail::SlowMPInt(lhs), detail::SlowMPInt(rhs)));
+}
+/// Returns floor(lhs / rhs). `rhs` must be non-zero.
+inline MPInt floorDiv(const MPInt &lhs, const MPInt &rhs) {
+  if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) {
+    // Division overflows only occur when negating the minimal possible value.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(lhs.getSmall(), rhs.getSmall())))
+      return -lhs;
+    return MPInt(detail::floorDivSmall(lhs.getSmall(), rhs.getSmall()));
+  }
+  return MPInt(floorDiv(detail::SlowMPInt(lhs), detail::SlowMPInt(rhs)));
+}
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+inline MPInt mod(const MPInt &lhs, const MPInt &rhs) {
+  if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall()))
+    return MPInt(mlir::mod(lhs.getSmall(), rhs.getSmall()));
+  return MPInt(mod(detail::SlowMPInt(lhs), detail::SlowMPInt(rhs)));
+}
+
+/// Returns the greatest common divisor of `a` and `b`.
+/// NOTE(review): lcm() only passes non-negative operands; the sign of the
+/// small-path result for negative operands is whatever
+/// llvm::greatestCommonDivisor yields — confirm callers do not depend on it.
+__attribute__((always_inline)) inline MPInt gcd(const MPInt &a,
+                                                const MPInt &b) {
+  if (LLVM_LIKELY(a.isSmall() && b.isSmall()))
+    return MPInt(
+        int64_t(llvm::greatestCommonDivisor(a.getSmall(), b.getSmall())));
+  return MPInt(gcd(detail::SlowMPInt(a), detail::SlowMPInt(b)));
+}
+
+/// Returns the least common multiple of 'a' and 'b'. Both operands are taken
+/// in absolute value, and the result is 0 if either operand is 0.
+inline MPInt lcm(const MPInt &a, const MPInt &b) {
+  MPInt x = abs(a);
+  MPInt y = abs(b);
+  // lcm(0, 0) would otherwise divide by gcd(0, 0) == 0.
+  if (x == 0 || y == 0)
+    return MPInt(0);
+  // gcd(x, y) divides x exactly; dividing first keeps the intermediate
+  // smaller, so it is more likely to stay on the 64-bit fast path.
+  return (x / gcd(x, y)) * y;
+}
+
+/// Returns the remainder of truncating division; on the 64-bit path this is
+/// C++ `%` (the sign follows the dividend). The result always fits when both
+/// operands are small, but evaluating INT64_MIN % -1 in C++ is undefined
+/// behaviour, so that case is answered directly.
+inline MPInt MPInt::operator%(const MPInt &o) const {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    // x % -1 is 0 for every x; avoid evaluating INT64_MIN % -1.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), o.getSmall())))
+      return MPInt(0);
+    return MPInt(getSmall() % o.getSmall());
+  }
+  return MPInt(detail::SlowMPInt(*this) % detail::SlowMPInt(o));
+}
+
+/// Negation. -INT64_MIN does not fit in 64 bits, so it is computed in
+/// arbitrary precision.
+inline MPInt MPInt::operator-() const {
+  if (LLVM_LIKELY(isSmall())) {
+    if (LLVM_LIKELY(getSmall() != std::numeric_limits<int64_t>::min()))
+      return MPInt(-getSmall());
+    return MPInt(-detail::SlowMPInt(*this));
+  }
+  return MPInt(-detail::SlowMPInt(*this));
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+/// The compound operators update the int64_t in place when both operands are
+/// small and the checked operation does not overflow; otherwise *this is
+/// reassigned from the arbitrary-precision result.
+__attribute__((always_inline)) inline MPInt &MPInt::operator+=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    int64_t result = getSmall();
+    bool overflow = detail::addOverflow(getSmall(), o.getSmall(), result);
+    if (LLVM_LIKELY(!overflow)) {
+      getSmall() = result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o));
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator-=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    int64_t result = getSmall();
+    bool overflow = detail::subOverflow(getSmall(), o.getSmall(), result);
+    if (LLVM_LIKELY(!overflow)) {
+      getSmall() = result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o));
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator*=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    int64_t result = getSmall();
+    bool overflow = detail::mulOverflow(getSmall(), o.getSmall(), result);
+    if (LLVM_LIKELY(!overflow)) {
+      getSmall() = result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o));
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator/=(const MPInt &o) {
+  if (LLVM_LIKELY(isSmall() && o.isSmall())) {
+    // Division overflows only occur when negating the minimal possible value.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), o.getSmall())))
+      return *this = -*this;
+    getSmall() /= o.getSmall();
+    return *this;
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o));
+}
+
+// Division overflows only occur when the divisor is -1.
+__attribute__((always_inline)) inline MPInt &
+MPInt::divByPositiveInPlace(const MPInt &o) {
+  assert(o > 0);
+  const bool bothSmall = isSmall() && o.isSmall();
+  if (LLVM_LIKELY(bothSmall)) {
+    // A positive divisor cannot hit the INT64_MIN / -1 overflow.
+    getSmall() /= o.getSmall();
+    return *this;
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o));
+}
+
+/// Replaces *this with *this % o.
+__attribute__((always_inline)) inline MPInt &MPInt::operator%=(const MPInt &o) {
+  MPInt remainder = *this % o;
+  return *this = remainder;
+}
+/// Pre-increment and pre-decrement, via the MPInt compound operators.
+__attribute__((always_inline)) inline MPInt &MPInt::operator++() {
+  return *this += MPInt(1);
+}
+__attribute__((always_inline)) inline MPInt &MPInt::operator--() {
+  return *this -= MPInt(1);
+}
+
+/// ----------------------------------------------------------------------------
+/// Convenience operator overloads for int64_t.
+/// ----------------------------------------------------------------------------
+/// The compound forms promote the int64_t operand to an MPInt and reuse the
+/// matching MPInt compound operator.
+__attribute__((always_inline)) inline MPInt &operator+=(MPInt &a, int64_t b) {
+  return a += MPInt(b);
+}
+__attribute__((always_inline)) inline MPInt &operator-=(MPInt &a, int64_t b) {
+  return a -= MPInt(b);
+}
+__attribute__((always_inline)) inline MPInt &operator*=(MPInt &a, int64_t b) {
+  return a *= MPInt(b);
+}
+__attribute__((always_inline)) inline MPInt &operator/=(MPInt &a, int64_t b) {
+  return a /= MPInt(b);
+}
+__attribute__((always_inline)) inline MPInt &operator%=(MPInt &a, int64_t b) {
+  return a %= MPInt(b);
+}
+/// The binary forms likewise promote the int64_t operand and defer to the
+/// MPInt operator.
+inline MPInt operator+(const MPInt &a, int64_t b) {
+  const MPInt rhs(b);
+  return a + rhs;
+}
+inline MPInt operator-(const MPInt &a, int64_t b) {
+  const MPInt rhs(b);
+  return a - rhs;
+}
+inline MPInt operator*(const MPInt &a, int64_t b) {
+  const MPInt rhs(b);
+  return a * rhs;
+}
+inline MPInt operator/(const MPInt &a, int64_t b) {
+  const MPInt rhs(b);
+  return a / rhs;
+}
+inline MPInt operator%(const MPInt &a, int64_t b) {
+  const MPInt rhs(b);
+  return a % rhs;
+}
+inline MPInt operator+(int64_t a, const MPInt &b) {
+  const MPInt lhs(a);
+  return lhs + b;
+}
+inline MPInt operator-(int64_t a, const MPInt &b) {
+  const MPInt lhs(a);
+  return lhs - b;
+}
+inline MPInt operator*(int64_t a, const MPInt &b) {
+  const MPInt lhs(a);
+  return lhs * b;
+}
+inline MPInt operator/(int64_t a, const MPInt &b) {
+  const MPInt lhs(a);
+  return lhs / b;
+}
+inline MPInt operator%(int64_t a, const MPInt &b) {
+  const MPInt lhs(a);
+  return lhs % b;
+}
+
+/// We provide special implementations of the comparison operators rather than
+/// calling through as above, as this would result in a 1.2x slowdown. Each one
+/// compares machine integers directly when the MPInt operand is small and
+/// otherwise defers to the SlowMPInt/int64_t comparison.
+inline bool operator==(const MPInt &a, int64_t b) {
+  return LLVM_LIKELY(a.isSmall()) ? a.getSmall() == b : a.getLarge() == b;
+}
+inline bool operator!=(const MPInt &a, int64_t b) {
+  return LLVM_LIKELY(a.isSmall()) ? a.getSmall() != b : a.getLarge() != b;
+}
+inline bool operator>(const MPInt &a, int64_t b) {
+  return LLVM_LIKELY(a.isSmall()) ? a.getSmall() > b : a.getLarge() > b;
+}
+inline bool operator<(const MPInt &a, int64_t b) {
+  return LLVM_LIKELY(a.isSmall()) ? a.getSmall() < b : a.getLarge() < b;
+}
+inline bool operator<=(const MPInt &a, int64_t b) {
+  return LLVM_LIKELY(a.isSmall()) ? a.getSmall() <= b : a.getLarge() <= b;
+}
+inline bool operator>=(const MPInt &a, int64_t b) {
+  return LLVM_LIKELY(a.isSmall()) ? a.getSmall() >= b : a.getLarge() >= b;
+}
+inline bool operator==(int64_t a, const MPInt &b) {
+  return LLVM_LIKELY(b.isSmall()) ? a == b.getSmall() : a == b.getLarge();
+}
+inline bool operator!=(int64_t a, const MPInt &b) {
+  return LLVM_LIKELY(b.isSmall()) ? a != b.getSmall() : a != b.getLarge();
+}
+inline bool operator>(int64_t a, const MPInt &b) {
+  return LLVM_LIKELY(b.isSmall()) ? a > b.getSmall() : a > b.getLarge();
+}
+inline bool operator<(int64_t a, const MPInt &b) {
+  return LLVM_LIKELY(b.isSmall()) ? a < b.getSmall() : a < b.getLarge();
+}
+inline bool operator<=(int64_t a, const MPInt &b) {
+  return LLVM_LIKELY(b.isSmall()) ? a <= b.getSmall() : a <= b.getLarge();
+}
+inline bool operator>=(int64_t a, const MPInt &b) {
+  return LLVM_LIKELY(b.isSmall()) ? a >= b.getSmall() : a >= b.getLarge();
+}
+
+} // namespace presburger
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_MPINT_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -2,6 +2,7 @@
   IntegerRelation.cpp
   LinearTransform.cpp
   Matrix.cpp
+  MPInt.cpp
   PresburgerRelation.cpp
   PresburgerSpace.cpp
   PWMAFunction.cpp
diff --git a/mlir/lib/Analysis/Presburger/MPInt.cpp b/mlir/lib/Analysis/Presburger/MPInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/MPInt.cpp
@@ -0,0 +1,46 @@
+//===- MPInt.cpp - MLIR MPInt Class ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/MPInt.h"
+#include "llvm/Support/MathExtras.h"
+#include <limits>
+
+using namespace mlir;
+using namespace presburger;
+
+/// Hashes `x` by value. Equal MPInts must hash equally, but a value that fits
+/// in 64 bits can be held in either representation (e.g. results of
+/// arithmetic with a large operand are always stored large), so large values
+/// that fit in int64_t are hashed exactly like small ones.
+llvm::hash_code mlir::presburger::hash_value(const MPInt &x) {
+  if (x.isSmall())
+    return llvm::hash_value(x.valSmall);
+  if (x >= std::numeric_limits<int64_t>::min() &&
+      x <= std::numeric_limits<int64_t>::max())
+    return llvm::hash_value(static_cast<int64_t>(x.valLarge));
+  // NOTE(review): this relies on detail::hash_value being independent of the
+  // SlowMPInt's internal bit width for equal values — confirm.
+  return detail::hash_value(x.valLarge);
+}
+
+/// ---------------------------------------------------------------------------
+/// Printing.
+/// ---------------------------------------------------------------------------
+llvm::raw_ostream &MPInt::print(llvm::raw_ostream &os) const {
+  if (isLarge())
+    return os << valLarge;
+  return os << valSmall;
+}
+
+void MPInt::dump() const { this->print(llvm::errs()); }
+
+llvm::raw_ostream &mlir::presburger::operator<<(llvm::raw_ostream &os,
+                                                const MPInt &x) {
+  // print() hands back the stream it was given, so forward it directly.
+  return x.print(os);
+}
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@
   IntegerRelationTest.cpp
   LinearTransformTest.cpp
   MatrixTest.cpp
+  MPIntTest.cpp
   PresburgerSetTest.cpp
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
diff --git a/mlir/unittests/Analysis/Presburger/MPIntTest.cpp b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp
@@ -0,0 +1,111 @@
+//===- MPIntTest.cpp - Tests for MPInt ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/MPInt.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <limits>
+
+using namespace mlir;
+using namespace presburger;
+
+TEST(MPIntTest, ops) {
+  MPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + five, ten);
+  EXPECT_EQ(five * five, 2 * ten + five);
+  EXPECT_EQ(five * five, 3 * ten - five);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(five / two, two);
+  EXPECT_EQ(five % two, two / two);
+
+  EXPECT_EQ(-ten % seven, -10 % 7);
+  EXPECT_EQ(ten % -seven, 10 % -7);
+  EXPECT_EQ(-ten % -seven, -10 % -7);
+  EXPECT_EQ(ten % seven, 10 % 7);
+
+  EXPECT_EQ(-ten / seven, -10 / 7);
+  EXPECT_EQ(ten / -seven, 10 / -7);
+  EXPECT_EQ(-ten / -seven, -10 / -7);
+  EXPECT_EQ(ten / seven, 10 / 7);
+
+  MPInt x = ten;
+  x += five;
+  EXPECT_EQ(x, 15);
+  x *= two;
+  EXPECT_EQ(x, 30);
+  x /= seven;
+  EXPECT_EQ(x, 4);
+  x -= two * 10;
+  EXPECT_EQ(x, -16);
+  x *= 2 * two;
+  EXPECT_EQ(x, -64);
+  x /= two / -2;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, ten);
+  EXPECT_GE(ten, ten);
+  EXPECT_EQ(ten, ten);
+  EXPECT_FALSE(ten != ten);
+  EXPECT_FALSE(ten < ten);
+  EXPECT_FALSE(ten > ten);
+  EXPECT_LT(five, ten);
+  EXPECT_GT(ten, five);
+}
+
+TEST(MPIntTest, ops64Overloads) {
+  MPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + 5, ten);
+  EXPECT_EQ(five + 5, 5 + five);
+  EXPECT_EQ(five * 5, 2 * ten + 5);
+  EXPECT_EQ(five * 5, 3 * ten - 5);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(5 / two, 2);
+  EXPECT_EQ(five / 2, 2);
+  EXPECT_EQ(2 % two, 0);
+  EXPECT_EQ(2 - two, 0);
+  EXPECT_EQ(2 % two, two % 2);
+
+  MPInt x = ten;
+  x += 5;
+  EXPECT_EQ(x, 15);
+  x *= 2;
+  EXPECT_EQ(x, 30);
+  x /= 7;
+  EXPECT_EQ(x, 4);
+  x -= 20;
+  EXPECT_EQ(x, -16);
+  x *= 4;
+  EXPECT_EQ(x, -64);
+  x /= -1;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, 10);
+  EXPECT_GE(ten, 10);
+  EXPECT_EQ(ten, 10);
+  EXPECT_FALSE(ten != 10);
+  EXPECT_FALSE(ten < 10);
+  EXPECT_FALSE(ten > 10);
+  EXPECT_LT(five, 10);
+  EXPECT_GT(ten, 5);
+
+  EXPECT_LE(10, ten);
+  EXPECT_GE(10, ten);
+  EXPECT_EQ(10, ten);
+  EXPECT_FALSE(10 != ten);
+  EXPECT_FALSE(10 < ten);
+  EXPECT_FALSE(10 > ten);
+  EXPECT_LT(5, ten);
+  EXPECT_GT(10, five);
+}
+
+TEST(MPIntTest, overflows) {
+  MPInt x(1ll << 60);
+  EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60));
+  MPInt y(1ll << 62);
+  EXPECT_EQ((y + y + y + y + y + y) / y, 6);
+  EXPECT_EQ(-(2 * (-y)), 2 * y); // -(-2^63) overflow.
+}
+
+// Rounding helpers and gcd/lcm on small values, plus the one int64_t
+// division (INT64_MIN / -1) that must be promoted to the large form.
+TEST(MPIntTest, divisionHelpers) {
+  MPInt two(2), three(3), seven(7);
+  EXPECT_EQ(floorDiv(seven, two), 3);
+  EXPECT_EQ(floorDiv(-seven, two), -4);
+  EXPECT_EQ(floorDiv(seven, -two), -4);
+  EXPECT_EQ(floorDiv(-seven, -two), 3);
+  EXPECT_EQ(ceilDiv(seven, two), 4);
+  EXPECT_EQ(ceilDiv(-seven, two), -3);
+  EXPECT_EQ(ceilDiv(seven, -two), -3);
+  EXPECT_EQ(ceilDiv(-seven, -two), 4);
+
+  EXPECT_EQ(mod(seven, three), 1);
+  EXPECT_EQ(mod(-seven, three), 2);
+
+  EXPECT_EQ(abs(-seven), 7);
+  EXPECT_EQ(gcd(MPInt(12), MPInt(18)), 6);
+  EXPECT_EQ(lcm(MPInt(4), MPInt(6)), 12);
+  EXPECT_EQ(lcm(MPInt(-4), MPInt(6)), 12);
+
+  MPInt minVal(std::numeric_limits<int64_t>::min());
+  EXPECT_EQ(minVal / -1, -minVal);
+  EXPECT_EQ(floorDiv(minVal, MPInt(-1)), -minVal);
+  EXPECT_EQ(ceilDiv(minVal, MPInt(-1)), -minVal);
+  EXPECT_GT(-minVal, std::numeric_limits<int64_t>::max());
+}