diff --git a/libcxx/benchmarks/CMakeLists.txt b/libcxx/benchmarks/CMakeLists.txt
--- a/libcxx/benchmarks/CMakeLists.txt
+++ b/libcxx/benchmarks/CMakeLists.txt
@@ -159,6 +159,7 @@ set(BENCHMARK_TESTS
     algorithms.partition_point.bench.cpp
     algorithms/equal.bench.cpp
+    algorithms/find.bench.cpp
     algorithms/lower_bound.bench.cpp
     algorithms/make_heap.bench.cpp
     algorithms/make_heap_then_sort_heap.bench.cpp
diff --git a/libcxx/benchmarks/algorithms/find.bench.cpp b/libcxx/benchmarks/algorithms/find.bench.cpp
new file mode 100644
--- /dev/null
+++ b/libcxx/benchmarks/algorithms/find.bench.cpp
@@ -0,0 +1,49 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include
+#include
+#include
+#include
+#include
+
+template
+static void bm_find(benchmark::State& state) { // times std::find for a single '2' hidden among state.range() '1's
+  std::vector vec1(state.range(), '1');
+  std::mt19937_64 rng(std::random_device{}()); // seeded from std::random_device, so match positions differ per run
+
+  for (auto _ : state) {
+    auto idx = rng() % vec1.size(); // move the only match to a fresh random index each iteration
+    vec1[idx] = '2';
+    benchmark::DoNotOptimize(vec1); // keep the optimizer from reasoning about vec1's contents across the search
+    benchmark::DoNotOptimize(std::find(vec1.begin(), vec1.end(), T('2')));
+    vec1[idx] = '1'; // restore the all-'1' vector for the next iteration
+  }
+}
+BENCHMARK(bm_find)->DenseRange(1, 8)->Range(16, 1 << 20); // sizes 1..8 individually, then a geometric range 16..2^20
+BENCHMARK(bm_find)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_find)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+template
+static void bm_ranges_find(benchmark::State& state) { // same measurement as bm_find, but through std::ranges::find
+  std::vector vec1(state.range(), '1');
+  std::mt19937_64 rng(std::random_device{}());
+
+  for (auto _ : state) {
+    auto idx = rng() % vec1.size();
+    vec1[idx] = '2';
+    benchmark::DoNotOptimize(vec1);
+    benchmark::DoNotOptimize(std::ranges::find(vec1, T('2')));
+    vec1[idx] = '1';
+  }
+}
+BENCHMARK(bm_ranges_find)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_ranges_find)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_ranges_find)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+BENCHMARK_MAIN();
diff --git a/libcxx/include/__algorithm/find.h b/libcxx/include/__algorithm/find.h
--- a/libcxx/include/__algorithm/find.h
+++ b/libcxx/include/__algorithm/find.h
@@ -10,7 +10,13 @@
 #ifndef _LIBCPP___ALGORITHM_FIND_H
 #define _LIBCPP___ALGORITHM_FIND_H
 
+#include <__algorithm/unwrap_iter.h>
 #include <__config>
+#include <__functional/identity.h>
+#include <__functional/invoke.h>
+#include <__string/constexpr_c_functions.h>
+#include <__type_traits/is_same.h>
+#include
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
@@ -18,15 +24,55 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template
-_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 _InputIterator
-find(_InputIterator __first, _InputIterator __last, const _Tp& __value) {
+// Generic implementation: linear scan comparing each projected element with __value.
+template
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Iter
+__find_impl(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
   for (; __first != __last; ++__first)
-    if (*__first == __value)
+    if (std::__invoke(__proj, *__first) == __value)
       break;
   return __first;
 }
 
+// Pointer ranges of 1-byte, trivially equality comparable elements: forward to memchr.
+template ::value && __is_trivially_equality_comparable<_Tp, _Up>::value &&
+                            sizeof(_Tp) == 1,
+                        int> = 0>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp*
+__find_impl(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
+  if (auto __ret = std::__constexpr_memchr(__first, __value, __last - __first))
+    return __ret;
+  return __last;
+}
+
+#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS
+// Pointer ranges of wchar_t-sized, at least wchar_t-aligned, trivially equality comparable elements: forward to wmemchr.
+template ::value && __is_trivially_equality_comparable<_Tp, _Up>::value &&
+                            sizeof(_Tp) == sizeof(wchar_t) && alignof(_Tp) >= alignof(wchar_t),
+                        int> = 0>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp*
+__find_impl(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
+  if (auto __ret = std::__constexpr_wmemchr(__first, __value, __last - __first))
+    return __ret;
+  return __last;
+}
+#endif // _LIBCPP_HAS_NO_WIDE_CHARACTERS
+
+template
+_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 _InputIterator
+find(_InputIterator __first, _InputIterator __last, const _Tp& __value) {
+  // Unwrapping lets iterators that __unwrap_iter reduces to raw pointers reach the memchr/wmemchr overloads.
+  __identity __proj;
+  return std::__rewrap_iter(
+      __first, std::__find_impl(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __value, __proj));
+}
+
 _LIBCPP_END_NAMESPACE_STD
 
 #endif // _LIBCPP___ALGORITHM_FIND_H
diff --git a/libcxx/include/__algorithm/ranges_find.h b/libcxx/include/__algorithm/ranges_find.h
--- a/libcxx/include/__algorithm/ranges_find.h
+++ b/libcxx/include/__algorithm/ranges_find.h
@@ -9,7 +9,9 @@
 #ifndef _LIBCPP___ALGORITHM_RANGES_FIND_H
 #define _LIBCPP___ALGORITHM_RANGES_FIND_H
 
+#include <__algorithm/find.h>
 #include <__algorithm/ranges_find_if.h>
+#include <__algorithm/unwrap_range.h>
 #include <__config>
 #include <__functional/identity.h>
 #include <__functional/invoke.h>
@@ -37,16 +39,21 @@
     requires indirect_binary_predicate, const _Tp*>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr _Ip
   operator()(_Ip __first, _Sp __last, const _Tp& __value, _Proj __proj = {}) const {
-    auto __pred = [&](auto&& __e) { return std::forward(__e) == __value; };
-    return ranges::__find_if_impl(std::move(__first), std::move(__last), __pred, __proj);
+    // Only forward iterators are unwrapped: __first is used again below to re-wrap the result.
+    if constexpr (forward_iterator<_Ip>) {
+      auto [__first_un, __last_un] = std::__unwrap_range(__first, std::move(__last));
+      return std::__rewrap_range<_Sp>(
+          std::move(__first), std::__find_impl(std::move(__first_un), std::move(__last_un), __value, __proj));
+    } else {
+      return std::__find_impl(std::move(__first), std::move(__last), __value, __proj);
+    }
   }
 
   template
     requires indirect_binary_predicate, _Proj>, const _Tp*>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr borrowed_iterator_t<_Rp>
   operator()(_Rp&& __r, const _Tp& __value, _Proj __proj = {}) const {
-    auto __pred = [&](auto&& __e) { return std::forward(__e) == __value; };
-    return ranges::__find_if_impl(ranges::begin(__r), ranges::end(__r), __pred, __proj);
+    return (*this)(ranges::begin(__r), ranges::end(__r), __value, std::move(__proj));
   }
 };
 } // namespace __find
diff --git a/libcxx/include/__string/char_traits.h b/libcxx/include/__string/char_traits.h
--- a/libcxx/include/__string/char_traits.h
+++ b/libcxx/include/__string/char_traits.h
@@ -224,7 +224,7 @@
     const char_type* find(const char_type* __s, size_t __n, const char_type& __a) _NOEXCEPT {
       if (__n == 0)
           return nullptr;
-      return std::__constexpr_char_memchr(__s, static_cast(__a), __n);
+      return std::__constexpr_memchr(__s, static_cast(__a), __n);
     }
 
     static inline _LIBCPP_CONSTEXPR_SINCE_CXX20
diff --git a/libcxx/include/__string/constexpr_c_functions.h b/libcxx/include/__string/constexpr_c_functions.h
--- a/libcxx/include/__string/constexpr_c_functions.h
+++ b/libcxx/include/__string/constexpr_c_functions.h
@@ -63,20 +63,28 @@
   }
 }
 
-inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 const char*
-__constexpr_char_memchr(const char* __str, int __char, size_t __count) {
-#if __has_builtin(__builtin_char_memchr)
-  return __builtin_char_memchr(__str, __char, __count);
-#else
-  if (!__libcpp_is_constant_evaluated())
-    return static_cast(__builtin_memchr(__str, __char, __count));
-  for (; __count; --__count) {
-    if (*__str == __char)
-      return __str;
-    ++__str;
-  }
-  return nullptr;
+// memchr that is also usable during constant evaluation; _Tp must satisfy the static_assert below.
+template
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __constexpr_memchr(_Tp* __str, int __char, size_t __count) {
+  static_assert(sizeof(_Tp) == 1 && __is_trivially_equality_comparable<_Tp, _Tp>::value,
+                "Calling memchr on non-trivially equality comparable types is unsafe.");
+
+  if (__libcpp_is_constant_evaluated()) {
+// use __builtin_char_memchr to optimize constexpr evaluation if we can (it accepts char and const char pointers)
+#if _LIBCPP_STD_VER >= 17 && __has_builtin(__builtin_char_memchr)
+    if constexpr (is_same<_Tp, char>::value || is_same<_Tp, const char>::value)
+      return __builtin_char_memchr(__str, __char, __count);
 #endif
+
+    for (; __count; --__count) {
+      if (*__str == __char)
+        return __str;
+      ++__str;
+    }
+    return nullptr;
+  } else {
+    return static_cast<_Tp*>(__builtin_memchr(__str, __char, __count));
+  }
 }
 
 _LIBCPP_END_NAMESPACE_STD
diff --git a/libcxx/include/cwchar b/libcxx/include/cwchar
--- a/libcxx/include/cwchar
+++ b/libcxx/include/cwchar
@@ -104,7 +104,11 @@
 
 #include <__assert> // all public C++ headers provide the assertion handler
 #include <__config>
+#include <__type_traits/apply_cv.h>
 #include <__type_traits/is_constant_evaluated.h>
+#include <__type_traits/is_equality_comparable.h>
+#include <__type_traits/is_same.h>
+#include <__type_traits/remove_cv.h>
 #include
 #include
 
@@ -222,21 +226,29 @@
 #endif
 }
 
-inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 const wchar_t*
-__constexpr_wmemchr(const wchar_t* __str, wchar_t __char, size_t __count) {
-#if __has_feature(cxx_constexpr_string_builtins)
-  return __builtin_wmemchr(__str, __char, __count);
-#else
-  if (!__libcpp_is_constant_evaluated())
-    return std::wmemchr(__str, __char, __count);
+// wmemchr that is also usable during constant evaluation; _Tp must satisfy the static_assert below.
+template
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp*
+__constexpr_wmemchr(_Tp* __str, wchar_t __char, size_t __count) {
+  static_assert(sizeof(_Tp) == sizeof(wchar_t) && __is_trivially_equality_comparable<_Tp, _Tp>::value,
+                "Calling wmemchr on non-trivially equality comparable types is unsafe.");
 
-  for (; __count; --__count) {
-    if (*__str == __char)
-      return __str;
-    ++__str;
-  }
-  return nullptr;
+  if (__libcpp_is_constant_evaluated()) {
+// use __builtin_wmemchr to optimize constexpr evaluation if we can; `if constexpr` requires C++17
+#if _LIBCPP_STD_VER >= 17 && __has_builtin(__builtin_wmemchr)
+    if constexpr (is_same<__remove_cv_t<_Tp>, wchar_t>::value)
+      return __builtin_wmemchr(__str, __char, __count);
 #endif
+    for (; __count; --__count) {
+      if (*__str == __char)
+        return __str;
+      ++__str;
+    }
+    return nullptr;
+  }
+
+  return reinterpret_cast<_Tp*>(
+      std::wmemchr(reinterpret_cast<typename __apply_cv<_Tp, wchar_t>::type*>(__str), __char, __count));
 }
 
 _LIBCPP_END_NAMESPACE_STD
diff --git a/libcxx/test/libcxx/strings/c.strings/constexpr.cstring.compile.pass.cpp
b/libcxx/test/libcxx/strings/c.strings/constexpr.cstring.compile.pass.cpp
--- a/libcxx/test/libcxx/strings/c.strings/constexpr.cstring.compile.pass.cpp
+++ b/libcxx/test/libcxx/strings/c.strings/constexpr.cstring.compile.pass.cpp
@@ -21,6 +21,6 @@
 
 constexpr bool test_constexpr_wmemchr() {
   const char str[] = "Banane";
-  return std::__constexpr_char_memchr(str, 'n', 6) == str + 2;
+  return std::__constexpr_memchr(str, 'n', 6) == str + 2;
 }
 static_assert(test_constexpr_wmemchr(), "");
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
@@ -15,32 +15,82 @@
 
 #include
 #include
+#include
 
 #include "test_macros.h"
 #include "test_iterators.h"
+#include "type_algorithms.h"
 
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {1, 3, 5, 2, 4, 6};
-    int ib[] = {1, 2, 3, 4, 5, 6};
-    return    (std::find(std::begin(ia), std::end(ia), 5) == ia+2)
-           && (std::find(std::begin(ib), std::end(ib), 9) == ib+6)
-           ;
+// Checks std::find over a 10-element array of ArrayT, searching for CompareT values through iterator type Iter.
+template
+struct Test {
+  template
+  TEST_CONSTEXPR_CXX20 void operator()() {
+    ArrayT arr[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+    static_assert(std::is_same::value, "");
+
+    { // first element matches
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(1));
+      assert(*iter == 1);
+      assert(base(iter) == arr);
     }
-#endif
 
-int main(int, char**)
-{
-    int ia[] = {0, 1, 2, 3, 4, 5};
-    const unsigned s = sizeof(ia)/sizeof(ia[0]);
-    cpp17_input_iterator r = std::find(cpp17_input_iterator(ia),
-                                                    cpp17_input_iterator(ia+s), 3);
-    assert(*r == 3);
-    r = std::find(cpp17_input_iterator(ia), cpp17_input_iterator(ia+s), 10);
-    assert(r == cpp17_input_iterator(ia+s));
-
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
+    { // some element matches
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(6));
+      assert(*iter == 6);
+      assert(base(iter) == arr + 5);
+    }
+
+    { // last element matches
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(10));
+      assert(*iter == 10);
+      assert(base(iter) == arr + 9);
+    }
+
+    { // if no element matches, last is returned
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(20));
+      assert(base(iter) == arr + 10);
+    }
+
+    { // empty range: first (== last) is returned
+      Iter iter = std::find(Iter(arr), Iter(arr), CompareT(1));
+      assert(base(iter) == arr);
+    }
+  }
+};
+
+// A class type that is comparable with integers; presumably exercises the generic (non-memchr) path of std::find.
+class Comparable {
+  int i_;
+
+public:
+  TEST_CONSTEXPR Comparable(int i) : i_(i) {}
+
+  TEST_CONSTEXPR friend bool operator==(const Comparable& lhs, long long rhs) { return lhs.i_ == rhs; }
+};
+
+template
+struct TestTypes {
+  template
+  TEST_CONSTEXPR_CXX20 void operator()() {
+    types::for_each(types::cpp17_input_iterator_list(), Test());
+  }
+};
+
+TEST_CONSTEXPR_CXX20 bool test() {
+  types::for_each(types::integer_types(), TestTypes());
+  types::for_each(types::integer_types(), TestTypes());
+  types::for_each(types::integer_types(), TestTypes());
+  types::for_each(types::integer_types(), TestTypes());
+
+  return true;
+}
+
+int main(int, char**) {
+  test();
+#if TEST_STD_VER >= 20
+  static_assert(test());
 #endif
 
   return 0;
diff --git a/libcxx/test/support/type_algorithms.h b/libcxx/test/support/type_algorithms.h
--- a/libcxx/test/support/type_algorithms.h
+++ b/libcxx/test/support/type_algorithms.h
@@ -95,7 +95,9 @@
 #endif
                                >;
 
-using integral_types = concatenate_t >;
+using integer_types = concatenate_t;
+
+using integral_types = concatenate_t >;
 
 using floating_point_types = type_list;
 