diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h
@@ -0,0 +1,150 @@
+//===- ArithmeticUtils.h - Arithmetic helper functions ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header is not part of the public API.  It is placed in the
+// includes directory only because that's required by the implementations
+// of template-classes.
+//
+// This file is part of the lightweight runtime support library for sparse
+// tensor manipulations.  The functionality of the support library is meant
+// to simplify benchmarking, testing, and debugging MLIR code operating on
+// sparse tensors.  However, the provided functionality is **not** part of
+// core MLIR itself.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
+#define MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
+
+#include <cassert>
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+namespace mlir {
+namespace sparse_tensor {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+//
+// Safe comparison functions.
+//
+// Variants of the `==`, `!=`, `<`, `<=`, `>`, and `>=` operators which
+// are careful to ensure that negatives are always considered strictly
+// less than non-negatives regardless of the signedness of the types of
+// the two arguments.  They are "safe" in that they guarantee to *always*
+// give an output and that that output is correct; in particular this means
+// they never use assertions or other mechanisms for "returning an error".
+//
+// These functions are C++17-compatible backports of the safe comparison
+// functions added in C++20, and the implementations are based on the
+// sample implementations provided by the standard:
+// <https://en.cppreference.com/w/cpp/utility/intcmp>.
+//
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U>
+constexpr bool safelyEQ(T t, U u) noexcept {
+  using UT = std::make_unsigned_t<T>;
+  using UU = std::make_unsigned_t<U>;
+  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
+    return t == u;
+  else if constexpr (std::is_signed_v<T>)
+    return t < 0 ? false : static_cast<UT>(t) == u;
+  else
+    return u < 0 ? false : t == static_cast<UU>(u);
+}
+
+template <typename T, typename U>
+constexpr bool safelyNE(T t, U u) noexcept {
+  return !safelyEQ(t, u);
+}
+
+template <typename T, typename U>
+constexpr bool safelyLT(T t, U u) noexcept {
+  using UT = std::make_unsigned_t<T>;
+  using UU = std::make_unsigned_t<U>;
+  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
+    return t < u;
+  else if constexpr (std::is_signed_v<T>)
+    return t < 0 ? true : static_cast<UT>(t) < u;
+  else
+    return u < 0 ? false : t < static_cast<UU>(u);
+}
+
+template <typename T, typename U>
+constexpr bool safelyGT(T t, U u) noexcept {
+  return safelyLT(u, t);
+}
+
+template <typename T, typename U>
+constexpr bool safelyLE(T t, U u) noexcept {
+  return !safelyGT(t, u);
+}
+
+template <typename T, typename U>
+constexpr bool safelyGE(T t, U u) noexcept {
+  return !safelyLT(t, u);
+}
+
+//===----------------------------------------------------------------------===//
+//
+// Overflow checking functions.
+//
+// These functions use assertions to ensure correctness with respect to
+// overflow/underflow.  Unlike the "safe" functions above, these "checked"
+// functions only guarantee that *if* they return an answer then that answer
+// is correct.  When assertions are enabled, they do their best to remain
+// as fast as possible (since MLIR keeps assertions enabled by default,
+// even for optimized builds).  When assertions are disabled, they use the
+// standard unchecked implementations.
+//
+//===----------------------------------------------------------------------===//
+
+// TODO: we would like to be able to pass in custom error messages, to
+// improve the user experience.  We should be able to use something like
+// `assert(((void)(msg ? msg : defaultMsg), cond))`; but I'm not entirely
+// sure that'll work as intended when done within a function-definition
+// rather than within a macro-definition.
+
+/// A version of `static_cast` which checks for overflow/underflow.
+/// The implementation avoids performing runtime assertions whenever
+/// the types alone are sufficient to statically prove that overflow
+/// cannot happen.
+template <typename To, typename From>
+[[nodiscard]] inline To checkOverflowCast(From x) {
+  // Check the lower bound. (For when casting from signed types.)
+  constexpr To minTo = std::numeric_limits<To>::min();
+  constexpr From minFrom = std::numeric_limits<From>::min();
+  if constexpr (!safelyGE(minFrom, minTo))
+    assert(safelyGE(x, minTo) && "cast would underflow");
+  // Check the upper bound.
+  constexpr To maxTo = std::numeric_limits<To>::max();
+  constexpr From maxFrom = std::numeric_limits<From>::max();
+  if constexpr (!safelyLE(maxFrom, maxTo))
+    assert(safelyLE(x, maxTo) && "cast would overflow");
+  // Now do the cast itself.
+  return static_cast<To>(x);
+}
+
+// TODO: would be better to use various architectures' intrinsics to
+// detect the overflow directly, instead of doing the assertion beforehand
+// (which requires an expensive division).
+//
+/// A version of `operator*` on `uint64_t` which guards against overflows
+/// (when assertions are enabled).
+inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
+  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
+         "Integer overflow");
+  return lhs * rhs;
+}
+
+} // namespace detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
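A note for readers of this patch: the point of these C++20 backports is that the built-in mixed-sign comparison operators silently convert the signed operand to unsigned. A minimal standalone sketch of the difference (not part of the patch; the `main` harness below is invented purely for illustration):

```cpp
#include <cstdio>

#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"

using namespace mlir::sparse_tensor::detail;

int main() {
  // Built-in mixed-sign comparison: -1 is converted to a huge unsigned
  // value, so the "obvious" answer is wrong.
  printf("built-in:  %d\n", -1 < 1u);          // prints 0
  // The backport treats any negative as strictly less than any
  // non-negative, regardless of the operands' signedness.
  printf("safelyLT:  %d\n", safelyLT(-1, 1u)); // prints 1
  // Same-signedness comparisons defer to the ordinary operators, and the
  // functions are constexpr, so they also work at compile time.
  static_assert(safelyEQ(42, 42u), "42 == 42u under the safe semantics");
  return 0;
}
```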
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/CheckedMul.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/CheckedMul.h
deleted file mode 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/CheckedMul.h
+++ /dev/null
@@ -1,48 +0,0 @@
-//===- CheckedMul.h - multiplication that checks for overflow ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This header is not part of the public API.  It is placed in the
-// includes directory only because that's required by the implementations
-// of template-classes.
-//
-// This file is part of the lightweight runtime support library for sparse
-// tensor manipulations.  The functionality of the support library is meant
-// to simplify benchmarking, testing, and debugging MLIR code operating on
-// sparse tensors.  However, the provided functionality is **not** part of
-// core MLIR itself.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_CHECKEDMUL_H
-#define MLIR_EXECUTIONENGINE_SPARSETENSOR_CHECKEDMUL_H
-
-#include <cassert>
-#include <cstdint>
-#include <limits>
-
-namespace mlir {
-namespace sparse_tensor {
-namespace detail {
-
-// TODO: would be better to use various architectures' intrinsics to
-// detect the overflow directly, instead of doing the assertion beforehand
-// (which requires an expensive division).
-//
-/// A version of `operator*` on `uint64_t` which guards against overflows
-/// (when assertions are enabled).
-inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
-  assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
-         "Integer overflow");
-  return lhs * rhs;
-}
-
-} // namespace detail
-} // namespace sparse_tensor
-} // namespace mlir
-
-#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_CHECKEDMUL_H
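The Storage.h changes below replace the open-coded assert-plus-`static_cast` pattern with `checkOverflowCast`. A sketch of why the templated version costs no more than the old hand-written checks (standalone illustration, not part of the patch):

```cpp
#include <cstdint>

#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"

using mlir::sparse_tensor::detail::checkOverflowCast;

int main() {
  // Narrowing uint64_t -> uint32_t: the lower-bound check is discarded at
  // compile time (both minima are zero), and only the upper-bound assert
  // remains -- exactly what the old open-coded pattern checked.
  uint32_t a = checkOverflowCast<uint32_t>(uint64_t{42}); // fine: 42 fits
  // checkOverflowCast<uint32_t>(uint64_t{1} << 40); // asserts: "cast would overflow"

  // Widening uint32_t -> uint64_t: both bounds are statically provable,
  // so this compiles down to a bare static_cast with no runtime checks.
  uint64_t b = checkOverflowCast<uint64_t>(uint32_t{7});
  return static_cast<int>(a + b) == 49 ? 0 : 1;
}
```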
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -35,9 +35,9 @@
 
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/ExecutionEngine/Float16bits.h"
+#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/Attributes.h"
 #include "mlir/ExecutionEngine/SparseTensor/COO.h"
-#include "mlir/ExecutionEngine/SparseTensor/CheckedMul.h"
 #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
 
 namespace mlir {
@@ -509,9 +509,10 @@
   /// the previous position and smaller than `indices[l].capacity()`).
   void appendPointer(uint64_t l, uint64_t pos, uint64_t count = 1) {
     ASSERT_COMPRESSED_LVL(l);
-    assert(pos <= std::numeric_limits<P>::max() &&
-           "Pointer value is too large for the P-type");
-    pointers[l].insert(pointers[l].end(), count, static_cast<P>(pos));
+    // TODO: we'd like to recover the nicer error message:
+    // "Pointer value is too large for the P-type"
+    pointers[l].insert(pointers[l].end(), count,
+                       detail::checkOverflowCast<P>(pos));
   }
 
   /// Appends index `i` to level `l`, in the semantically general sense.
@@ -526,9 +527,9 @@
   void appendIndex(uint64_t l, uint64_t full, uint64_t i) {
     const auto dlt = getLvlType(l); // Avoid redundant bounds checking.
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
-      assert(i <= std::numeric_limits<I>::max() &&
-             "Index value is too large for the I-type");
-      indices[l].push_back(static_cast<I>(i));
+      // TODO: we'd like to recover the nicer error message:
+      // "Index value is too large for the I-type"
+      indices[l].push_back(detail::checkOverflowCast<I>(i));
     } else { // Dense dimension.
       ASSERT_DENSE_DLT(dlt);
       assert(i >= full && "Index was already filled");
@@ -551,9 +552,9 @@
     // entry has been initialized; thus we must be sure to check `size()`
     // here, instead of `capacity()` as would be ideal.
     assert(pos < indices[l].size() && "Index position is out of bounds");
-    assert(i <= std::numeric_limits<I>::max() &&
-           "Index value is too large for the I-type");
-    indices[l][pos] = static_cast<I>(i);
+    // TODO: we'd like to recover the nicer error message:
+    // "Index value is too large for the I-type"
+    indices[l][pos] = detail::checkOverflowCast<I>(i);
   }
 
   /// Computes the assembled-size associated with the `l`-th level,
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6900,9 +6900,9 @@
         "lib/ExecutionEngine/SparseTensor/Storage.cpp",
     ],
     hdrs = [
+        "include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h",
         "include/mlir/ExecutionEngine/SparseTensor/Attributes.h",
         "include/mlir/ExecutionEngine/SparseTensor/COO.h",
-        "include/mlir/ExecutionEngine/SparseTensor/CheckedMul.h",
         "include/mlir/ExecutionEngine/SparseTensor/ErrorHandling.h",
         "include/mlir/ExecutionEngine/SparseTensor/File.h",
         "include/mlir/ExecutionEngine/SparseTensor/PermutationRef.h",
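One remark on the TODO above `checkedMul`: the "architectures' intrinsics" it alludes to are available through GCC/Clang's `__builtin_mul_overflow`, which detects the wraparound from the multiplication itself instead of paying for an up-front division. A sketch of that alternative (an illustration of the TODO only, not something this patch implements; MSVC would need a different mechanism, which is presumably why the portable division-based assert is kept for now):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical intrinsic-based variant of detail::checkedMul. The builtin
// stores the (possibly wrapped) product in `result` and returns true iff
// the multiplication overflowed, so no division is needed.
inline uint64_t checkedMulViaIntrinsic(uint64_t lhs, uint64_t rhs) {
  uint64_t result;
  const bool overflowed = __builtin_mul_overflow(lhs, rhs, &result);
  assert(!overflowed && "Integer overflow");
  (void)overflowed; // avoid unused-variable warnings in NDEBUG builds
  // With assertions disabled this degrades to the standard unchecked
  // (wrapping) multiplication, matching checkedMul's behavior.
  return result;
}
```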