diff --git a/mlir/include/mlir/Analysis/Presburger/TPInt.h b/mlir/include/mlir/Analysis/Presburger/TPInt.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/TPInt.h
@@ -0,0 +1,137 @@
+//===- TPInt.h - MLIR TPInt Class -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a simple class to represent arbitrary precision signed integers.
+// Unlike APInt, one does not have to specify a fixed maximum size, and the
+// integer can take on any arbitrary values.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_TPINT_H
+#define MLIR_ANALYSIS_PRESBURGER_TPINT_H
+
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace presburger {
+
+/// This class provides support for multi-precision arithmetic.
+///
+/// Unlike APInt, this extends the precision as necessary to prevent overflows
+/// and supports operations between objects with differing internal precisions.
+///
+/// Since it uses APInt internally, TPInt (TransPrecision Int) stores values in
+/// a 64-bit machine integer internally for small values and uses slower
+/// arbitrary-precision arithmetic only for larger values.
+class TPInt {
+public:
+  explicit TPInt(int64_t val) : val(APSInt::get(val)) {}
+  TPInt() : TPInt(0) {}
+  /// The held value is always treated as signed, so an unsigned APSInt is
+  /// zero-extended by one bit to preserve its value.
+  explicit TPInt(const APSInt &val)
+      : val(val.isSigned() ? val
+                           : APSInt(val.zext(val.getBitWidth() + 1),
+                                    /*isUnsigned=*/false)) {}
+  TPInt &operator=(int64_t val) { return *this = TPInt(val); }
+  /// The held value is expected to fit in an int64_t.
+  explicit operator int64_t() const { return val.getSExtValue(); }
+  TPInt operator-() const;
+  bool operator==(const TPInt &o) const;
+  bool operator!=(const TPInt &o) const;
+  bool operator>(const TPInt &o) const;
+  bool operator<(const TPInt &o) const;
+  bool operator<=(const TPInt &o) const;
+  bool operator>=(const TPInt &o) const;
+  TPInt operator+(const TPInt &o) const;
+  TPInt operator-(const TPInt &o) const;
+  TPInt operator*(const TPInt &o) const;
+  TPInt operator/(const TPInt &o) const;
+  TPInt operator%(const TPInt &o) const;
+  TPInt &operator+=(const TPInt &o);
+  TPInt &operator-=(const TPInt &o);
+  TPInt &operator*=(const TPInt &o);
+  TPInt &operator/=(const TPInt &o);
+  TPInt &operator%=(const TPInt &o);
+
+  TPInt &operator++();
+  TPInt &operator--();
+
+  friend TPInt abs(const TPInt &x);
+  friend TPInt ceilDiv(const TPInt &lhs, const TPInt &rhs);
+  friend TPInt floorDiv(const TPInt &lhs, const TPInt &rhs);
+  friend TPInt greatestCommonDivisor(const TPInt &a, const TPInt &b);
+  /// Overload to compute a hash_code for a TPInt value. Equal values hash
+  /// equally, independent of their internal bit widths.
+  friend llvm::hash_code hash_value(const TPInt &x); // NOLINT
+
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const;
+  void dump() const;
+
+private:
+  unsigned getBitWidth() const { return val.getBitWidth(); }
+
+  // The held integer value.
+  //
+  // TODO: consider using APInt directly to avoid unnecessary repeated internal
+  // signedness checks. This requires refactoring, exposing, or duplicating
+  // APSInt::compareValues.
+  APSInt val;
+};
+
+/// This just calls through to the operator int64_t, but it's useful when a
+/// function pointer is required.
+inline int64_t int64FromTPInt(const TPInt &x) { return int64_t(x); }
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TPInt &x);
+
+/// Returns lhs mod rhs. The RHS is always expected to be positive, and the
+/// result is always non-negative.
+TPInt mod(const TPInt &lhs, const TPInt &rhs);
+
+/// Returns the least common multiple of 'a' and 'b'; lcm(0, 0) is 0.
+TPInt lcm(const TPInt &a, const TPInt &b);
+
+/// Convenience overloads for 64-bit integers.
+TPInt &operator+=(TPInt &a, int64_t b);
+TPInt &operator-=(TPInt &a, int64_t b);
+TPInt &operator*=(TPInt &a, int64_t b);
+TPInt &operator/=(TPInt &a, int64_t b);
+TPInt &operator%=(TPInt &a, int64_t b);
+
+bool operator==(const TPInt &a, int64_t b);
+bool operator!=(const TPInt &a, int64_t b);
+bool operator>(const TPInt &a, int64_t b);
+bool operator<(const TPInt &a, int64_t b);
+bool operator<=(const TPInt &a, int64_t b);
+bool operator>=(const TPInt &a, int64_t b);
+TPInt operator+(const TPInt &a, int64_t b);
+TPInt operator-(const TPInt &a, int64_t b);
+TPInt operator*(const TPInt &a, int64_t b);
+TPInt operator/(const TPInt &a, int64_t b);
+TPInt operator%(const TPInt &a, int64_t b);
+
+bool operator==(int64_t a, const TPInt &b);
+bool operator!=(int64_t a, const TPInt &b);
+bool operator>(int64_t a, const TPInt &b);
+bool operator<(int64_t a, const TPInt &b);
+bool operator<=(int64_t a, const TPInt &b);
+bool operator>=(int64_t a, const TPInt &b);
+TPInt operator+(int64_t a, const TPInt &b);
+TPInt operator-(int64_t a, const TPInt &b);
+TPInt operator*(int64_t a, const TPInt &b);
+TPInt operator/(int64_t a, const TPInt &b);
+TPInt operator%(int64_t a, const TPInt &b);
+
+} // namespace presburger
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_TPINT_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -6,6 +6,7 @@
   PresburgerSpace.cpp
   PWMAFunction.cpp
   Simplex.cpp
+  TPInt.cpp
   Utils.cpp
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Analysis/Presburger/TPInt.cpp b/mlir/lib/Analysis/Presburger/TPInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/TPInt.cpp
@@ -0,0 +1,246 @@
+//===- TPInt.cpp - MLIR TPInt Class ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/TPInt.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace presburger;
+
+/// Equal values may be held at different bit widths, so hash a canonical
+/// representation: the value truncated to the minimum number of bits needed to
+/// represent it as a signed integer. This keeps hashing consistent with ==.
+llvm::hash_code mlir::presburger::hash_value(const TPInt &x) {
+  return hash_value(x.val.sextOrTrunc(x.val.getMinSignedBits()));
+}
+
+/// ---------------------------------------------------------------------------
+/// Printing.
+/// ---------------------------------------------------------------------------
+llvm::raw_ostream &TPInt::print(llvm::raw_ostream &os) const {
+  return os << val;
+}
+
+void TPInt::dump() const { print(llvm::errs()); }
+
+llvm::raw_ostream &mlir::presburger::operator<<(llvm::raw_ostream &os,
+                                                const TPInt &x) {
+  x.print(os);
+  return os;
+}
+
+/// ---------------------------------------------------------------------------
+/// Comparison operators.
+/// ---------------------------------------------------------------------------
+bool TPInt::operator==(const TPInt &o) const {
+  return APSInt::compareValues(val, o.val) == 0;
+}
+bool TPInt::operator!=(const TPInt &o) const {
+  return APSInt::compareValues(val, o.val) != 0;
+}
+bool TPInt::operator>(const TPInt &o) const {
+  return APSInt::compareValues(val, o.val) > 0;
+}
+bool TPInt::operator<(const TPInt &o) const {
+  return APSInt::compareValues(val, o.val) < 0;
+}
+bool TPInt::operator<=(const TPInt &o) const {
+  return APSInt::compareValues(val, o.val) <= 0;
+}
+bool TPInt::operator>=(const TPInt &o) const {
+  return APSInt::compareValues(val, o.val) >= 0;
+}
+
+/// ---------------------------------------------------------------------------
+/// Arithmetic operators.
+/// ---------------------------------------------------------------------------
+using APIntOvOp = APInt (APInt::*)(const APInt &b, bool &overflow) const;
+
+/// Bring a and b to have the same width and then call a.op(b, overflow).
+/// If the overflow bit becomes set, resize a and b to double the width and
+/// call a.op(b, overflow), returning its result. The operation with double
+/// widths should not also overflow.
+static APSInt doOpExpandIfOverflow(const APInt &a, const APInt &b,
+                                   APIntOvOp op) {
+  bool overflow = false;
+  unsigned width = std::max(a.getBitWidth(), b.getBitWidth());
+  // This calls a.sextOrSelf(width).op(b.sextOrSelf(width), overflow).
+  // TODO: in C++17 we can use the simpler syntax with std::invoke.
+  APInt ret = ((a.sextOrSelf(width)).*(op))(b.sextOrSelf(width), overflow);
+  if (!overflow)
+    return APSInt(ret, /*isUnsigned=*/false);
+
+  width *= 2;
+  // This calls a.sextOrSelf(width).op(b.sextOrSelf(width), overflow).
+  ret = ((a.sextOrSelf(width)).*(op))(b.sextOrSelf(width), overflow);
+  assert(!overflow && "double width should be sufficient to avoid overflow!");
+  return APSInt(ret, /*isUnsigned=*/false);
+}
+
+TPInt TPInt::operator+(const TPInt &o) const {
+  return TPInt(doOpExpandIfOverflow(val, o.val, &APInt::sadd_ov));
+}
+TPInt TPInt::operator-(const TPInt &o) const {
+  return TPInt(doOpExpandIfOverflow(val, o.val, &APInt::ssub_ov));
+}
+TPInt TPInt::operator*(const TPInt &o) const {
+  return TPInt(doOpExpandIfOverflow(val, o.val, &APInt::smul_ov));
+}
+TPInt TPInt::operator/(const TPInt &o) const {
+  return TPInt(doOpExpandIfOverflow(val, o.val, &APInt::sdiv_ov));
+}
+
+namespace mlir {
+namespace presburger {
+using llvm::APIntOps::GreatestCommonDivisor;
+using llvm::APIntOps::RoundingSDiv;
+TPInt abs(const TPInt &x) { return x >= 0 ? x : -x; }
+
+TPInt ceilDiv(const TPInt &lhs, const TPInt &rhs) {
+  // Dividing by -1 is the only case that can overflow; negation widens the
+  // result as needed.
+  if (rhs == -1)
+    return -lhs;
+  // RoundingSDiv requires both operands to have the same bit width.
+  unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+  return TPInt(APSInt(RoundingSDiv(lhs.val.sextOrSelf(width),
+                                   rhs.val.sextOrSelf(width),
+                                   APInt::Rounding::UP),
+                      /*isUnsigned=*/false));
+}
+
+TPInt floorDiv(const TPInt &lhs, const TPInt &rhs) {
+  // Dividing by -1 is the only case that can overflow; negation widens the
+  // result as needed.
+  if (rhs == -1)
+    return -lhs;
+  // RoundingSDiv requires both operands to have the same bit width.
+  unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+  return TPInt(APSInt(RoundingSDiv(lhs.val.sextOrSelf(width),
+                                   rhs.val.sextOrSelf(width),
+                                   APInt::Rounding::DOWN),
+                      /*isUnsigned=*/false));
+}
+
+/// The RHS is always expected to be positive, and the result is always
+/// non-negative.
+TPInt mod(const TPInt &lhs, const TPInt &rhs) {
+  assert(rhs >= 1);
+  TPInt rem = lhs % rhs;
+  return rem < 0 ? rem + rhs : rem;
+}
+
+TPInt greatestCommonDivisor(const TPInt &a, const TPInt &b) {
+  // GreatestCommonDivisor requires operands of the same bit width and treats
+  // them as unsigned, hence the absolute values.
+  unsigned width = std::max(a.getBitWidth(), b.getBitWidth());
+  APInt result = GreatestCommonDivisor(a.val.sextOrSelf(width).abs(),
+                                       b.val.sextOrSelf(width).abs());
+  // The (unsigned) result can be 2^(width - 1), e.g. when an operand is the
+  // minimum signed value. Zero-extend it by one bit so that it is not misread
+  // as a negative signed value.
+  if (result.isNegative())
+    result = result.zext(width + 1);
+  return TPInt(APSInt(result, /*isUnsigned=*/false));
+}
+
+/// Returns the least common multiple of 'a' and 'b'; lcm(0, 0) is 0.
+TPInt lcm(const TPInt &a, const TPInt &b) {
+  TPInt x = abs(a);
+  TPInt y = abs(b);
+  // The lcm with zero is zero; this also avoids dividing by gcd(0, 0) == 0.
+  if (x == 0 || y == 0)
+    return TPInt(0);
+  return (x * y) / greatestCommonDivisor(x, y);
+}
+} // namespace presburger
+} // namespace mlir
+
+/// This operation cannot overflow.
+TPInt TPInt::operator%(const TPInt &o) const {
+  unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth());
+  return TPInt(APSInt(val.sextOrSelf(width).srem(o.val.sextOrSelf(width)),
+                      /*isUnsigned=*/false));
+}
+
+TPInt TPInt::operator-() const {
+  if (val.isMinSignedValue()) {
+    // Overflow only occurs when the value is the minimum possible value.
+    APSInt ret = val.extend(2 * val.getBitWidth());
+    return TPInt(-ret);
+  }
+  return TPInt(-val);
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+TPInt &TPInt::operator+=(const TPInt &o) {
+  *this = *this + o;
+  return *this;
+}
+TPInt &TPInt::operator-=(const TPInt &o) {
+  *this = *this - o;
+  return *this;
+}
+TPInt &TPInt::operator*=(const TPInt &o) {
+  *this = *this * o;
+  return *this;
+}
+TPInt &TPInt::operator/=(const TPInt &o) {
+  *this = *this / o;
+  return *this;
+}
+TPInt &TPInt::operator%=(const TPInt &o) {
+  *this = *this % o;
+  return *this;
+}
+TPInt &TPInt::operator++() {
+  *this += 1;
+  return *this;
+}
+
+TPInt &TPInt::operator--() {
+  *this -= 1;
+  return *this;
+}
+
+/// ---------------------------------------------------------------------------
+/// Convenience operator overloads for int64_t.
+/// ---------------------------------------------------------------------------
+namespace mlir {
+namespace presburger {
+TPInt &operator+=(TPInt &a, int64_t b) { return a += TPInt(b); }
+TPInt &operator-=(TPInt &a, int64_t b) { return a -= TPInt(b); }
+TPInt &operator*=(TPInt &a, int64_t b) { return a *= TPInt(b); }
+TPInt &operator/=(TPInt &a, int64_t b) { return a /= TPInt(b); }
+TPInt &operator%=(TPInt &a, int64_t b) { return a %= TPInt(b); }
+bool operator==(const TPInt &a, int64_t b) { return a == TPInt(b); }
+bool operator!=(const TPInt &a, int64_t b) { return a != TPInt(b); }
+bool operator>(const TPInt &a, int64_t b) { return a > TPInt(b); }
+bool operator<(const TPInt &a, int64_t b) { return a < TPInt(b); }
+bool operator<=(const TPInt &a, int64_t b) { return a <= TPInt(b); }
+bool operator>=(const TPInt &a, int64_t b) { return a >= TPInt(b); }
+TPInt operator+(const TPInt &a, int64_t b) { return a + TPInt(b); }
+TPInt operator-(const TPInt &a, int64_t b) { return a - TPInt(b); }
+TPInt operator*(const TPInt &a, int64_t b) { return a * TPInt(b); }
+TPInt operator/(const TPInt &a, int64_t b) { return a / TPInt(b); }
+TPInt operator%(const TPInt &a, int64_t b) { return a % TPInt(b); }
+bool operator==(int64_t a, const TPInt &b) { return TPInt(a) == b; }
+bool operator!=(int64_t a, const TPInt &b) { return TPInt(a) != b; }
+bool operator>(int64_t a, const TPInt &b) { return TPInt(a) > b; }
+bool operator<(int64_t a, const TPInt &b) { return TPInt(a) < b; }
+bool operator<=(int64_t a, const TPInt &b) { return TPInt(a) <= b; }
+bool operator>=(int64_t a, const TPInt &b) { return TPInt(a) >= b; }
+TPInt operator+(int64_t a, const TPInt &b) { return TPInt(a) + b; }
+TPInt operator-(int64_t a, const TPInt &b) { return TPInt(a) - b; }
+TPInt operator*(int64_t a, const TPInt &b) { return TPInt(a) * b; }
+TPInt operator/(int64_t a, const TPInt &b) { return TPInt(a) / b; }
+TPInt operator%(int64_t a, const TPInt &b) { return TPInt(a) % b; }
+} // namespace presburger
+} // namespace mlir
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -6,6 +6,7 @@
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
+  TPIntTest.cpp
   ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
 )
 
diff --git a/mlir/unittests/Analysis/Presburger/TPIntTest.cpp b/mlir/unittests/Analysis/Presburger/TPIntTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/TPIntTest.cpp
@@ -0,0 +1,156 @@
+//===- TPIntTest.cpp - Tests for TPInt ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/TPInt.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+TEST(TPIntTest, ops) {
+  TPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + five, ten);
+  EXPECT_EQ(five * five, 2 * ten + five);
+  EXPECT_EQ(five * five, 3 * ten - five);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(five / two, two);
+  EXPECT_EQ(five % two, two / two);
+
+  EXPECT_EQ(-ten % seven, -10 % 7);
+  EXPECT_EQ(ten % -seven, 10 % -7);
+  EXPECT_EQ(-ten % -seven, -10 % -7);
+  EXPECT_EQ(ten % seven, 10 % 7);
+
+  EXPECT_EQ(-ten / seven, -10 / 7);
+  EXPECT_EQ(ten / -seven, 10 / -7);
+  EXPECT_EQ(-ten / -seven, -10 / -7);
+  EXPECT_EQ(ten / seven, 10 / 7);
+
+  TPInt x = ten;
+  x += five;
+  EXPECT_EQ(x, 15);
+  x *= two;
+  EXPECT_EQ(x, 30);
+  x /= seven;
+  EXPECT_EQ(x, 4);
+  x -= two * 10;
+  EXPECT_EQ(x, -16);
+  x *= 2 * two;
+  EXPECT_EQ(x, -64);
+  x /= two / -2;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, ten);
+  EXPECT_GE(ten, ten);
+  EXPECT_EQ(ten, ten);
+  EXPECT_FALSE(ten != ten);
+  EXPECT_FALSE(ten < ten);
+  EXPECT_FALSE(ten > ten);
+  EXPECT_LT(five, ten);
+  EXPECT_GT(ten, five);
+}
+
+TEST(TPIntTest, ops64Overloads) {
+  TPInt two(2), five(5), seven(7), ten(10);
+  EXPECT_EQ(five + 5, ten);
+  EXPECT_EQ(five + 5, 5 + five);
+  EXPECT_EQ(five * 5, 2 * ten + 5);
+  EXPECT_EQ(five * 5, 3 * ten - 5);
+  EXPECT_EQ(five * two, ten);
+  EXPECT_EQ(5 / two, 2);
+  EXPECT_EQ(five / 2, 2);
+  EXPECT_EQ(2 % two, 0);
+  EXPECT_EQ(2 - two, 0);
+  EXPECT_EQ(2 % two, two % 2);
+
+  TPInt x = ten;
+  x += 5;
+  EXPECT_EQ(x, 15);
+  x *= 2;
+  EXPECT_EQ(x, 30);
+  x /= 7;
+  EXPECT_EQ(x, 4);
+  x -= 20;
+  EXPECT_EQ(x, -16);
+  x *= 4;
+  EXPECT_EQ(x, -64);
+  x /= -1;
+  EXPECT_EQ(x, 64);
+
+  EXPECT_LE(ten, 10);
+  EXPECT_GE(ten, 10);
+  EXPECT_EQ(ten, 10);
+  EXPECT_FALSE(ten != 10);
+  EXPECT_FALSE(ten < 10);
+  EXPECT_FALSE(ten > 10);
+  EXPECT_LT(five, 10);
+  EXPECT_GT(ten, 5);
+
+  EXPECT_LE(10, ten);
+  EXPECT_GE(10, ten);
+  EXPECT_EQ(10, ten);
+  EXPECT_FALSE(10 != ten);
+  EXPECT_FALSE(10 < ten);
+  EXPECT_FALSE(10 > ten);
+  EXPECT_LT(5, ten);
+  EXPECT_GT(10, five);
+}
+
+TEST(TPIntTest, overflows) {
+  TPInt x(1ll << 60);
+  EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60));
+}
+
+TEST(TPIntTest, mixedWidthDivision) {
+  // 2^64 does not fit in 64 bits, so `big` is held at a larger bit width than
+  // the 64-bit operands it is combined with below.
+  TPInt big = TPInt(1ll << 62) * 4;
+  TPInt three(3);
+  EXPECT_EQ(floorDiv(big, three), 6148914691236517205);
+  EXPECT_EQ(ceilDiv(big, three), 6148914691236517206);
+  EXPECT_EQ(floorDiv(-big, three), -6148914691236517206);
+  EXPECT_EQ(ceilDiv(-big, three), -6148914691236517205);
+  EXPECT_EQ(floorDiv(TPInt(7), big), 0);
+  EXPECT_EQ(ceilDiv(TPInt(7), big), 1);
+  EXPECT_EQ(floorDiv(TPInt(-7), big), -1);
+  EXPECT_EQ(ceilDiv(TPInt(-7), big), 0);
+}
+
+TEST(TPIntTest, gcdAndLcm) {
+  TPInt big = TPInt(1ll << 62) * 4;
+  EXPECT_EQ(greatestCommonDivisor(big, TPInt(6)), 2);
+  EXPECT_EQ(greatestCommonDivisor(TPInt(-12), TPInt(18)), 6);
+  EXPECT_EQ(lcm(big, TPInt(6)), 3 * big);
+  EXPECT_EQ(lcm(TPInt(-4), TPInt(6)), 12);
+  EXPECT_EQ(lcm(TPInt(0), TPInt(0)), 0);
+
+  // The minimum int64_t value; its absolute value does not fit in 64 bits.
+  TPInt minVal = TPInt(-(1ll << 62)) * 2;
+  EXPECT_EQ(greatestCommonDivisor(minVal, TPInt(0)), -minVal);
+  EXPECT_GT(greatestCommonDivisor(minVal, TPInt(0)), 0);
+}
+
+TEST(TPIntTest, hashIsWidthIndependent) {
+  TPInt big = TPInt(1ll << 62) * 4;
+  // These hold small values, but at the larger bit width of `big`.
+  TPInt wideFive = big / big * 5;
+  TPInt wideMinusOne = big / -big;
+  EXPECT_EQ(wideFive, 5);
+  EXPECT_EQ(wideMinusOne, -1);
+  EXPECT_TRUE(hash_value(wideFive) == hash_value(TPInt(5)));
+  EXPECT_TRUE(hash_value(wideMinusOne) == hash_value(TPInt(-1)));
+}
+
+TEST(TPIntTest, unsignedAPSInt) {
+  // An unsigned APSInt whose top bit is set must keep its (positive) value.
+  TPInt x(llvm::APSInt(llvm::APInt(8, 255), /*isUnsigned=*/true));
+  EXPECT_EQ(x, 255);
+  EXPECT_EQ(x + 1, 256);
+  EXPECT_EQ(int64_t(x), 255);
+}