diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -77,6 +77,14 @@ swap(std::forward(lhs), std::forward(rhs)); } +using std::size; + +template +constexpr auto size_impl(RangeT &&range) + -> decltype(size(std::forward(range))) { + return size(std::forward(range)); +} + } // end namespace adl_detail /// Returns the begin iterator to \p range using `std::begin` and @@ -103,6 +111,14 @@ adl_detail::swap_impl(std::forward(lhs), std::forward(rhs)); } +/// Returns the size of \p range using `std::size` and functions found through +/// Argument-Dependent Lookup (ADL). +template +constexpr auto adl_size(RangeT &&range) + -> decltype(adl_detail::size_impl(std::forward(range))) { + return adl_detail::size_impl(std::forward(range)); +} + namespace detail { template @@ -745,6 +761,8 @@ template bool all_equal(std::initializer_list Values); +template constexpr size_t range_size(R &&Range); + namespace detail { using std::declval; @@ -936,9 +954,7 @@ template detail::zippy zip_equal(T &&t, U &&u, Args &&...args) { - assert(all_equal({std::distance(adl_begin(t), adl_end(t)), - std::distance(adl_begin(u), adl_end(u)), - std::distance(adl_begin(args), adl_end(args))...}) && + assert(all_equal({range_size(t), range_size(u), range_size(args)...}) && "Iteratees do not have equal length"); return detail::zippy( std::forward(t), std::forward(u), std::forward(args)...); @@ -951,9 +967,7 @@ template detail::zippy zip_first(T &&t, U &&u, Args &&...args) { - assert(std::distance(adl_begin(t), adl_end(t)) <= - std::min({std::distance(adl_begin(u), adl_end(u)), - std::distance(adl_begin(args), adl_end(args))...}) && + assert(range_size(t) <= std::min({range_size(u), range_size(args)...}) && "First iteratee is not the shortest"); return detail::zippy( @@ -1769,6 +1783,29 @@ return std::distance(Range.begin(), Range.end()); } +namespace detail { +template +using 
check_has_free_function_size = + decltype(adl_size(std::declval())); + +template +static constexpr bool HasFreeFunctionSize = + is_detected::value; +} // namespace detail + +/// Returns the size of the \p Range, i.e., the number of elements. This +/// implementation takes inspiration from `std::ranges::size` from C++20 and +/// delegates the size check to `adl_size` or `std::distance`, in this order of +/// preference. Unlike `llvm::size`, this function does *not* guarantee O(1) +/// running time, and is intended to be used in generic code that does not know +/// the exact range type. +template constexpr size_t range_size(R &&Range) { + if constexpr (detail::HasFreeFunctionSize) + return adl_size(Range); + else + return std::distance(adl_begin(Range), adl_end(Range)); +} + /// Provide wrappers to std::for_each which take ranges instead of having to /// pass begin/end explicitly. template @@ -2386,8 +2423,7 @@ template auto enumerate(FirstRange &&First, RestRanges &&...Rest) { assert((sizeof...(Rest) == 0 || - all_equal({std::distance(adl_begin(First), adl_end(First)), - std::distance(adl_begin(Rest), adl_end(Rest))...})) && + all_equal({range_size(First), range_size(Rest)...})) && "Ranges have different length"); using enumerator = detail::zippy; diff --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp --- a/llvm/unittests/ADT/IteratorTest.cpp +++ b/llvm/unittests/ADT/IteratorTest.cpp @@ -743,4 +743,65 @@ EXPECT_EQ(std::distance(v2.begin(), v2.end()), size(v2)); } +TEST(RangeSizeTest, CommonRangeTypes) { + SmallVector v1 = {1, 2, 3}; + EXPECT_EQ(range_size(v1), 3u); + + std::map m1 = {{1, 1}, {2, 2}}; + EXPECT_EQ(range_size(m1), 2u); + + auto it_range = llvm::make_range(m1.begin(), m1.end()); + EXPECT_EQ(range_size(it_range), 2u); + + static constexpr int c_arr[5] = {}; + static_assert(range_size(c_arr) == 5u); + + static constexpr std::array cpp_arr = {}; + static_assert(range_size(cpp_arr) == 6u); +} + +struct FooWithMemberSize { + 
size_t size() const { return 42; } + auto begin() { return Data.begin(); } + auto end() { return Data.end(); } + + std::set Data; +}; + +TEST(RangeSizeTest, MemberSize) { + // Make sure that member `.size()` is preferred over the free function and + // `std::distance`. + FooWithMemberSize container; + EXPECT_EQ(range_size(container), 42u); +} + +struct FooWithFreeSize { + friend size_t size(const FooWithFreeSize &) { return 13; } + auto begin() { return Data.begin(); } + auto end() { return Data.end(); } + + std::set Data; +}; + +TEST(RangeSizeTest, FreeSize) { + // Make sure that `size(x)` is preferred over `std::distance`. + FooWithFreeSize container; + EXPECT_EQ(range_size(container), 13u); +} + +struct FooWithDistance { + auto begin() { return Data.begin(); } + auto end() { return Data.end(); } + + std::set Data; +}; + +TEST(RangeSizeTest, Distance) { + // Make sure that we can fall back to `std::distance` even if the iterator is + // not random-access. + FooWithDistance container; + EXPECT_EQ(range_size(container), 0u); + container.Data = {1, 2, 3, 4}; + EXPECT_EQ(range_size(container), 4u); +} } // anonymous namespace diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -575,6 +575,39 @@ SUCCEED(); } +struct FooWithMemberSize { + size_t size() const { return 42; } + auto begin() { return Data.begin(); } + auto end() { return Data.end(); } + + std::set Data; +}; + +namespace some_namespace { +struct FooWithFreeSize { + auto begin() { return Data.begin(); } + auto end() { return Data.end(); } + + std::set Data; +}; + +size_t size(const FooWithFreeSize &) { return 13; } +} // namespace some_namespace + +TEST(STLExtrasTest, ADLSizeTest) { + FooWithMemberSize foo1; + EXPECT_EQ(adl_size(foo1), 42u); + + some_namespace::FooWithFreeSize foo2; + EXPECT_EQ(adl_size(foo2), 13u); + + static constexpr int c_arr[] = {1, 2, 3}; + 
static_assert(adl_size(c_arr) == 3u); + + static constexpr std::array cpp_arr = {}; + static_assert(adl_size(cpp_arr) == 4u); +} + TEST(STLExtrasTest, DropBeginTest) { SmallVector vec{0, 1, 2, 3, 4};