diff --git a/mlir/include/mlir/Analysis/Presburger/SlowMPInt.h b/mlir/include/mlir/Analysis/Presburger/SlowMPInt.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/SlowMPInt.h
@@ -0,0 +1,134 @@
+//===- SlowMPInt.h - MLIR SlowMPInt Class -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares SlowMPInt, a simple arbitrary-precision signed integer.
+// Unlike with APInt, no fixed maximum bit width has to be chosen up front: the
+// value may grow to any magnitude.
+//
+// This class is intended as the slow fallback path of the
+// soon-to-be-implemented MPInt class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H
+#define MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H
+
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace presburger {
+namespace detail {
+
+/// Multi-precision signed integer arithmetic. The value is held in an APInt
+/// whose bit width is grown whenever a result would not otherwise fit. This is
+/// primarily meant as the slow fallback path for the upcoming MPInt class.
+class SlowMPInt {
+public:
+  /// Constructs a value from a 64-bit signed integer.
+  explicit SlowMPInt(int64_t val);
+  /// Constructs the value zero.
+  SlowMPInt();
+  /// Wraps an existing APInt; its bits are interpreted as a signed value.
+  explicit SlowMPInt(const llvm::APInt &val);
+  SlowMPInt &operator=(int64_t val);
+  /// Converts to int64_t. The value is expected to fit in 64 bits.
+  explicit operator int64_t() const;
+
+  /// Comparisons act on the signed values, regardless of the bit widths
+  /// currently used to store them.
+  bool operator==(const SlowMPInt &o) const;
+  bool operator!=(const SlowMPInt &o) const;
+  bool operator>(const SlowMPInt &o) const;
+  bool operator<(const SlowMPInt &o) const;
+  bool operator<=(const SlowMPInt &o) const;
+  bool operator>=(const SlowMPInt &o) const;
+
+  /// Negation, +, -, * and / widen their result instead of overflowing.
+  /// Division truncates towards zero and % returns the matching remainder.
+  SlowMPInt operator-() const;
+  SlowMPInt operator+(const SlowMPInt &o) const;
+  SlowMPInt operator-(const SlowMPInt &o) const;
+  SlowMPInt operator*(const SlowMPInt &o) const;
+  SlowMPInt operator/(const SlowMPInt &o) const;
+  SlowMPInt operator%(const SlowMPInt &o) const;
+  SlowMPInt &operator+=(const SlowMPInt &o);
+  SlowMPInt &operator-=(const SlowMPInt &o);
+  SlowMPInt &operator*=(const SlowMPInt &o);
+  SlowMPInt &operator/=(const SlowMPInt &o);
+  SlowMPInt &operator%=(const SlowMPInt &o);
+
+  SlowMPInt &operator++();
+  SlowMPInt &operator--();
+
+  friend SlowMPInt abs(const SlowMPInt &x);
+  friend SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
+  friend SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
+  friend SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b);
+
+  /// Overload to compute a hash_code for a SlowMPInt value.
+  friend llvm::hash_code hash_value(const SlowMPInt &x); // NOLINT
+
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const;
+  void dump() const;
+
+  /// Returns the current width of the underlying APInt; it may exceed 64.
+  unsigned getBitWidth() const { return val.getBitWidth(); }
+
+private:
+  llvm::APInt val;
+};
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const SlowMPInt &x);
+
+/// Returns the remainder of dividing lhs by rhs, normalized to lie in
+/// [0, rhs). The divisor rhs must be positive.
+SlowMPInt mod(const SlowMPInt &lhs, const SlowMPInt &rhs);
+
+/// Returns the least common multiple of 'a' and 'b'.
+SlowMPInt lcm(const SlowMPInt &a, const SlowMPInt &b);
+
+/// Redeclarations of the friend functions declared inside SlowMPInt, so that
+/// they are also found by ordinary (non-ADL) name lookup.
+SlowMPInt abs(const SlowMPInt &x);
+SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
+SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs);
+SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b);
+llvm::hash_code hash_value(const SlowMPInt &x); // NOLINT
+
+/// ---------------------------------------------------------------------------
+/// Convenience operator overloads for int64_t.
+/// ---------------------------------------------------------------------------
+SlowMPInt &operator+=(SlowMPInt &a, int64_t b);
+SlowMPInt &operator-=(SlowMPInt &a, int64_t b);
+SlowMPInt &operator*=(SlowMPInt &a, int64_t b);
+SlowMPInt &operator/=(SlowMPInt &a, int64_t b);
+SlowMPInt &operator%=(SlowMPInt &a, int64_t b);
+
+bool operator==(const SlowMPInt &a, int64_t b);
+bool operator!=(const SlowMPInt &a, int64_t b);
+bool operator>(const SlowMPInt &a, int64_t b);
+bool operator<(const SlowMPInt &a, int64_t b);
+bool operator<=(const SlowMPInt &a, int64_t b);
+bool operator>=(const SlowMPInt &a, int64_t b);
+SlowMPInt operator+(const SlowMPInt &a, int64_t b);
+SlowMPInt operator-(const SlowMPInt &a, int64_t b);
+SlowMPInt operator*(const SlowMPInt &a, int64_t b);
+SlowMPInt operator/(const SlowMPInt &a, int64_t b);
+SlowMPInt operator%(const SlowMPInt &a, int64_t b);
+
+bool operator==(int64_t a, const SlowMPInt &b);
+bool operator!=(int64_t a, const SlowMPInt &b);
+bool operator>(int64_t a, const SlowMPInt &b);
+bool operator<(int64_t a, const SlowMPInt &b);
+bool operator<=(int64_t a, const SlowMPInt &b);
+bool operator>=(int64_t a, const SlowMPInt &b);
+SlowMPInt operator+(int64_t a, const SlowMPInt &b);
+SlowMPInt operator-(int64_t a, const SlowMPInt &b);
+SlowMPInt operator*(int64_t a, const SlowMPInt &b);
+SlowMPInt operator/(int64_t a, const SlowMPInt &b);
+SlowMPInt operator%(int64_t a, const SlowMPInt &b);
+} // namespace detail
+} // namespace presburger
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -6,6 +6,7 @@
   PresburgerSpace.cpp
   PWMAFunction.cpp
   Simplex.cpp
+  SlowMPInt.cpp
   Utils.cpp
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Analysis/Presburger/SlowMPInt.cpp b/mlir/lib/Analysis/Presburger/SlowMPInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/SlowMPInt.cpp
@@ -0,0 +1,227 @@
+//===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/SlowMPInt.h"
+#include "llvm/Support/MathExtras.h"
+#include <algorithm>
+#include <functional>
+
+namespace mlir {
+namespace presburger {
+namespace detail {
+
+// Values built from an int64_t start out 64 bits wide; arithmetic widens them
+// when a result would overflow.
+SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {}
+SlowMPInt::SlowMPInt() : SlowMPInt(0) {}
+SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
+SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); }
+// NOTE: getSExtValue() requires the value to fit in 64 bits.
+SlowMPInt::operator int64_t() const { return val.getSExtValue(); }
+
+/// Equal values can be stored at different bit widths, because operator==
+/// compares after sign-extending to a common width. The APInt hash is not
+/// guaranteed to ignore the width. So hash a canonical form: the value
+/// truncated to the fewest bits that still represent it as a signed integer.
+/// This guarantees that a == b implies hash_value(a) == hash_value(b).
+llvm::hash_code hash_value(const SlowMPInt &x) {
+  const llvm::APInt &v = x.val;
+  unsigned minSignedBits = v.getBitWidth() - v.getNumSignBits() + 1;
+  return hash_value(v.sextOrTrunc(minSignedBits));
+}
+
+/// ---------------------------------------------------------------------------
+/// Printing.
+/// ---------------------------------------------------------------------------
+llvm::raw_ostream &SlowMPInt::print(llvm::raw_ostream &os) const {
+  os << val;
+  return os;
+}
+
+/// Writes the value to llvm::errs(); no trailing newline is emitted.
+void SlowMPInt::dump() const { llvm::errs() << *this; }
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const SlowMPInt &x) {
+  return x.print(os);
+}
+
+/// ---------------------------------------------------------------------------
+/// Convenience operator overloads for int64_t.
+///
+/// Each overload wraps its int64_t operand in a SlowMPInt and defers to the
+/// SlowMPInt-only operators.
+/// ---------------------------------------------------------------------------
+SlowMPInt &operator+=(SlowMPInt &a, int64_t b) {
+  a += SlowMPInt(b);
+  return a;
+}
+SlowMPInt &operator-=(SlowMPInt &a, int64_t b) {
+  a -= SlowMPInt(b);
+  return a;
+}
+SlowMPInt &operator*=(SlowMPInt &a, int64_t b) {
+  a *= SlowMPInt(b);
+  return a;
+}
+SlowMPInt &operator/=(SlowMPInt &a, int64_t b) {
+  a /= SlowMPInt(b);
+  return a;
+}
+SlowMPInt &operator%=(SlowMPInt &a, int64_t b) {
+  a %= SlowMPInt(b);
+  return a;
+}
+
+// Mixed comparisons put the promoted operand on the left and mirror the
+// operator where needed, e.g. `a > b` is evaluated as `SlowMPInt(b) < a`.
+bool operator==(const SlowMPInt &a, int64_t b) { return SlowMPInt(b) == a; }
+bool operator!=(const SlowMPInt &a, int64_t b) { return SlowMPInt(b) != a; }
+bool operator>(const SlowMPInt &a, int64_t b) { return SlowMPInt(b) < a; }
+bool operator<(const SlowMPInt &a, int64_t b) { return SlowMPInt(b) > a; }
+bool operator<=(const SlowMPInt &a, int64_t b) { return SlowMPInt(b) >= a; }
+bool operator>=(const SlowMPInt &a, int64_t b) { return SlowMPInt(b) <= a; }
+SlowMPInt operator+(const SlowMPInt &a, int64_t b) { return a + SlowMPInt{b}; }
+SlowMPInt operator-(const SlowMPInt &a, int64_t b) { return a - SlowMPInt{b}; }
+SlowMPInt operator*(const SlowMPInt &a, int64_t b) { return a * SlowMPInt{b}; }
+SlowMPInt operator/(const SlowMPInt &a, int64_t b) { return a / SlowMPInt{b}; }
+SlowMPInt operator%(const SlowMPInt &a, int64_t b) { return a % SlowMPInt{b}; }
+
+bool operator==(int64_t a, const SlowMPInt &b) { return b == SlowMPInt(a); }
+bool operator!=(int64_t a, const SlowMPInt &b) { return b != SlowMPInt(a); }
+bool operator>(int64_t a, const SlowMPInt &b) { return b < SlowMPInt(a); }
+bool operator<(int64_t a, const SlowMPInt &b) { return b > SlowMPInt(a); }
+bool operator<=(int64_t a, const SlowMPInt &b) { return b >= SlowMPInt(a); }
+bool operator>=(int64_t a, const SlowMPInt &b) { return b <= SlowMPInt(a); }
+SlowMPInt operator+(int64_t a, const SlowMPInt &b) { return SlowMPInt{a} + b; }
+SlowMPInt operator-(int64_t a, const SlowMPInt &b) { return SlowMPInt{a} - b; }
+SlowMPInt operator*(int64_t a, const SlowMPInt &b) { return SlowMPInt{a} * b; }
+SlowMPInt operator/(int64_t a, const SlowMPInt &b) { return SlowMPInt{a} / b; }
+SlowMPInt operator%(int64_t a, const SlowMPInt &b) { return SlowMPInt{a} % b; }
+
+/// Returns the larger of the two operands' bit widths.
+static unsigned getMaxWidth(const APInt &lhs, const APInt &rhs) {
+  const unsigned lhsWidth = lhs.getBitWidth();
+  const unsigned rhsWidth = rhs.getBitWidth();
+  return lhsWidth < rhsWidth ? rhsWidth : lhsWidth;
+}
+
+/// ---------------------------------------------------------------------------
+/// Comparison operators.
+/// ---------------------------------------------------------------------------
+
+// TODO: consider instead making APInt::compare available and using that.
+bool SlowMPInt::operator==(const SlowMPInt &o) const {
+  const unsigned commonWidth = getMaxWidth(val, o.val);
+  return val.sext(commonWidth) == o.val.sext(commonWidth);
+}
+bool SlowMPInt::operator<(const SlowMPInt &o) const {
+  const unsigned commonWidth = getMaxWidth(val, o.val);
+  return val.sext(commonWidth).slt(o.val.sext(commonWidth));
+}
+// The remaining comparisons are derived from == and <. This is valid because
+// sign-extending both operands to a common width yields a total order.
+bool SlowMPInt::operator!=(const SlowMPInt &o) const { return !(*this == o); }
+bool SlowMPInt::operator>(const SlowMPInt &o) const { return o < *this; }
+bool SlowMPInt::operator<=(const SlowMPInt &o) const { return !(o < *this); }
+bool SlowMPInt::operator>=(const SlowMPInt &o) const { return !(*this < o); }
+
+/// ---------------------------------------------------------------------------
+/// Arithmetic operators.
+/// ---------------------------------------------------------------------------
+
+/// Sign-extends a and b to a common width and evaluates op(a, b, overflow).
+/// If overflow is reported, both operands are re-extended to twice that width
+/// and op is evaluated again. That second evaluation is not expected to
+/// overflow.
+template <typename Function>
+static APInt runOpWithExpandOnOverflow(const APInt &a, const APInt &b,
+                                       const Function &op) {
+  bool overflow = false;
+  unsigned width = getMaxWidth(a, b);
+  APInt ret = op(a.sext(width), b.sext(width), overflow);
+  if (!overflow)
+    return ret;
+
+  // The signed sum, difference, product or quotient of two w-bit values always
+  // fits in 2w bits, so a single doubling is enough.
+  width *= 2;
+  ret = op(a.sext(width), b.sext(width), overflow);
+  assert(!overflow && "double width should be sufficient to avoid overflow!");
+  return ret;
+}
+
+SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
+  return SlowMPInt(
+      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov)));
+}
+SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
+  return SlowMPInt(
+      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov)));
+}
+SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
+  return SlowMPInt(
+      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov)));
+}
+/// Truncating division. The only overflowing case, the minimum value divided
+/// by -1, is absorbed by the width expansion.
+SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
+  return SlowMPInt(
+      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov)));
+}
+
+/// Returns |x|. Negation widens the minimum signed value, so this cannot
+/// overflow.
+SlowMPInt abs(const SlowMPInt &x) { return x >= 0 ? x : -x; }
+
+/// Returns lhs / rhs rounded towards positive infinity.
+SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
+  // Dividing the minimum signed value by -1 overflows. Negation widens as
+  // needed, so handle this divisor separately.
+  if (rhs == -1)
+    return -lhs;
+  // RoundingSDiv requires operands of equal bit width, but the widths of two
+  // SlowMPInts are independent of each other.
+  unsigned width = getMaxWidth(lhs.val, rhs.val);
+  return SlowMPInt(llvm::APIntOps::RoundingSDiv(
+      lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::UP));
+}
+
+/// Returns lhs / rhs rounded towards negative infinity.
+SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
+  // See ceilDiv: the -1 divisor is the only overflowing case.
+  if (rhs == -1)
+    return -lhs;
+  unsigned width = getMaxWidth(lhs.val, rhs.val);
+  return SlowMPInt(llvm::APIntOps::RoundingSDiv(
+      lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::DOWN));
+}
+
+/// Returns the remainder of dividing lhs by rhs, normalized to lie in
+/// [0, rhs). The divisor rhs must be positive.
+SlowMPInt mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
+  assert(rhs >= 1 && "mod is only supported for positive divisors!");
+  SlowMPInt rem = lhs % rhs;
+  return rem < 0 ? rem + rhs : rem;
+}
+
+/// Returns the greatest common divisor of |a| and |b|. The result is
+/// non-negative.
+SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b) {
+  // GreatestCommonDivisor requires operands of equal bit width and treats them
+  // as unsigned. Extending by one extra bit has two effects. First, abs()
+  // cannot overflow, since |MIN| is not representable at the original width.
+  // Second, the result (at most max(|a|, |b|)) stays non-negative when read
+  // back as a signed value.
+  unsigned width = getMaxWidth(a.val, b.val) + 1;
+  return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(
+      a.val.sext(width).abs(), b.val.sext(width).abs()));
+}
+
+/// Returns the least common multiple of 'a' and 'b'.
+/// Both operands are taken in absolute value, and lcm(0, 0) is defined as 0.
+SlowMPInt lcm(const SlowMPInt &a, const SlowMPInt &b) {
+  SlowMPInt x = abs(a);
+  SlowMPInt y = abs(b);
+  // gcd(0, 0) is 0, so the general formula below would divide by zero.
+  if (x == 0 && y == 0)
+    return SlowMPInt(0);
+  // gcd(x, y) divides x exactly. Dividing before multiplying keeps the
+  // intermediate value smaller.
+  return (x / gcd(x, y)) * y;
+}
+
+/// Remainder of truncating division, so the result takes the sign of the
+/// dividend. This operation cannot overflow.
+SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
+  unsigned width = getMaxWidth(val, o.val);
+  return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
+}
+
+SlowMPInt SlowMPInt::operator-() const {
+  // Negation overflows only for the minimum signed value of the current
+  // width, so widen first in that case.
+  if (val.isMinSignedValue()) {
+    APInt ret = val.sext(2 * val.getBitWidth());
+    return SlowMPInt(-ret);
+  }
+  return SlowMPInt(-val);
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
+  *this = *this + o;
+  return *this;
+}
+SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
+  *this = *this - o;
+  return *this;
+}
+SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
+  *this = *this * o;
+  return *this;
+}
+SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
+  *this = *this / o;
+  return *this;
+}
+SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
+  *this = *this % o;
+  return *this;
+}
+SlowMPInt &SlowMPInt::operator++() {
+  *this += 1;
+  return *this;
+}
+
+SlowMPInt &SlowMPInt::operator--() {
+  *this -= 1;
+  return *this;
+}
+
+} // namespace detail
+} // namespace presburger
+} // namespace mlir
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -7,6 +7,7 @@
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
+  SlowMPIntTest.cpp
   ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
 )
 
diff --git a/mlir/unittests/Analysis/Presburger/SlowMPIntTest.cpp b/mlir/unittests/Analysis/Presburger/SlowMPIntTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/SlowMPIntTest.cpp
@@ -0,0 +1,117 @@
+//===- SlowMPIntTest.cpp - Tests for SlowMPInt ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/SlowMPInt.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+using detail::SlowMPInt;
+
+TEST(SlowMPIntTest, ops) {
+  SlowMPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + five, ten);
+  EXPECT_EQ(five * five, 2 * ten + five);
+  EXPECT_EQ(five * five, 3 * ten - five);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(five / two, two);
+  EXPECT_EQ(five % two, two / two);
+
+  EXPECT_EQ(-ten % seven, -10 % 7);
+  EXPECT_EQ(ten % -seven, 10 % -7);
+  EXPECT_EQ(-ten % -seven, -10 % -7);
+  EXPECT_EQ(ten % seven, 10 % 7);
+
+  EXPECT_EQ(-ten / seven, -10 / 7);
+  EXPECT_EQ(ten / -seven, 10 / -7);
+  EXPECT_EQ(-ten / -seven, -10 / -7);
+  EXPECT_EQ(ten / seven, 10 / 7);
+
+  SlowMPInt x = ten;
+  x += five;
+  EXPECT_EQ(x, 15);
+  x *= two;
+  EXPECT_EQ(x, 30);
+  x /= seven;
+  EXPECT_EQ(x, 4);
+  x -= two * 10;
+  EXPECT_EQ(x, -16);
+  x *= 2 * two;
+  EXPECT_EQ(x, -64);
+  x /= two / -2;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, ten);
+  EXPECT_GE(ten, ten);
+  EXPECT_EQ(ten, ten);
+  EXPECT_FALSE(ten != ten);
+  EXPECT_FALSE(ten < ten);
+  EXPECT_FALSE(ten > ten);
+  EXPECT_LT(five, ten);
+  EXPECT_GT(ten, five);
+}
+
+TEST(SlowMPIntTest, ops64Overloads) {
+  SlowMPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + 5, ten);
+  EXPECT_EQ(five + 5, 5 + five);
+  EXPECT_EQ(five * 5, 2 * ten + 5);
+  EXPECT_EQ(five * 5, 3 * ten - 5);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(5 / two, 2);
+  EXPECT_EQ(five / 2, 2);
+  EXPECT_EQ(2 % two, 0);
+  EXPECT_EQ(2 - two, 0);
+  EXPECT_EQ(2 % two, two % 2);
+
+  SlowMPInt x = ten;
+  x += 5;
+  EXPECT_EQ(x, 15);
+  x *= 2;
+  EXPECT_EQ(x, 30);
+  x /= 7;
+  EXPECT_EQ(x, 4);
+  x -= 20;
+  EXPECT_EQ(x, -16);
+  x *= 4;
+  EXPECT_EQ(x, -64);
+  x /= -1;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, 10);
+  EXPECT_GE(ten, 10);
+  EXPECT_EQ(ten, 10);
+  EXPECT_FALSE(ten != 10);
+  EXPECT_FALSE(ten < 10);
+  EXPECT_FALSE(ten > 10);
+  EXPECT_LT(five, 10);
+  EXPECT_GT(ten, 5);
+
+  EXPECT_LE(10, ten);
+  EXPECT_GE(10, ten);
+  EXPECT_EQ(10, ten);
+  EXPECT_FALSE(10 != ten);
+  EXPECT_FALSE(10 < ten);
+  EXPECT_FALSE(10 > ten);
+  EXPECT_LT(5, ten);
+  EXPECT_GT(10, five);
+}
+
+TEST(SlowMPIntTest, overflows) {
+  SlowMPInt x(1ll << 60);
+  EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60));
+}
+
+TEST(SlowMPIntTest, lcm) {
+  SlowMPInt zero(0), four(4), six(6);
+  EXPECT_EQ(lcm(four, six), 12);
+  EXPECT_EQ(lcm(-four, six), 12);
+  EXPECT_EQ(lcm(zero, six), 0);
+  EXPECT_EQ(lcm(zero, zero), 0);
+}