diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt --- a/libcxx/include/CMakeLists.txt +++ b/libcxx/include/CMakeLists.txt @@ -70,6 +70,7 @@ __algorithm/pop_heap.h __algorithm/prev_permutation.h __algorithm/pstl_any_all_none_of.h + __algorithm/pstl_find.h __algorithm/pstl_for_each.h __algorithm/push_heap.h __algorithm/ranges_adjacent_find.h diff --git a/libcxx/include/__algorithm/pstl_find.h b/libcxx/include/__algorithm/pstl_find.h new file mode 100644 --- /dev/null +++ b/libcxx/include/__algorithm/pstl_find.h @@ -0,0 +1,102 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _LIBCPP___ALGORITHM_PSTL_FIND_H +#define _LIBCPP___ALGORITHM_PSTL_FIND_H + +#include <__algorithm/comp.h> +#include <__algorithm/find.h> +#include <__config> +#include <__functional/not_fn.h> +#include <__pstl/internal/parallel_impl.h> +#include <__pstl/internal/unseq_backend_simd.h> +#include <__type_traits/is_execution_policy.h> +#include <__utility/terminate_on_exception.h> + +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) +# pragma GCC system_header +#endif + +_LIBCPP_BEGIN_NAMESPACE_STD + +template >, int> = 0> +_LIBCPP_HIDE_FROM_ABI _ForwardIterator +find(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value) { + if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator>::value) { + return std::__terminate_on_exception([&] { + return __pstl::__internal::__parallel_find( + __pstl::__internal::__par_backend_tag{}, + __policy, + __first, + __last, + [&__policy, &__value](_ForwardIterator __brick_first, _ForwardIterator 
__brick_last) { + return std::find(std::__remove_parallel_policy(__policy), __brick_first, __brick_last, __value); + }, + less<>{}, + true); + }); + } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator>::value) { + using __diff_t = __iter_diff_t<_ForwardIterator>; + return __pstl::__unseq_backend::__simd_first( + __first, __diff_t(0), __last - __first, [&__value](_ForwardIterator __iter, __diff_t __i) { + return __iter[__i] == __value; + }); + } else { + return std::find(__first, __last, __value); + } +} + +template >, int> = 0> +_LIBCPP_HIDE_FROM_ABI _ForwardIterator +find_if(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) { + if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator>::value) { + return std::__terminate_on_exception([&] { + return __pstl::__internal::__parallel_find( + __pstl::__internal::__par_backend_tag{}, + __policy, + __first, + __last, + [&__policy, &__pred](_ForwardIterator __brick_first, _ForwardIterator __brick_last) { + return std::find_if(std::__remove_parallel_policy(__policy), __brick_first, __brick_last, __pred); + }, + less<>{}, + true); + }); + } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && + __is_cpp17_random_access_iterator<_ForwardIterator>::value) { + using __diff_t = __iter_diff_t<_ForwardIterator>; + return __pstl::__unseq_backend::__simd_first( + __first, __diff_t(0), __last - __first, [&__pred](_ForwardIterator __iter, __diff_t __i) { + return __pred(__iter[__i]); + }); + } else { + return std::find_if(__first, __last, __pred); + } +} + +template >, int> = 0> +_LIBCPP_HIDE_FROM_ABI _ForwardIterator +find_if_not(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) { + return std::find_if(__policy, __first, __last, std::not_fn(std::move(__pred))); +} + 
+_LIBCPP_END_NAMESPACE_STD + +#endif // _LIBCPP___ALGORITHM_PSTL_FIND_H diff --git a/libcxx/include/__pstl/internal/algorithm_fwd.h b/libcxx/include/__pstl/internal/algorithm_fwd.h --- a/libcxx/include/__pstl/internal/algorithm_fwd.h +++ b/libcxx/include/__pstl/internal/algorithm_fwd.h @@ -322,31 +322,6 @@ _RandomAccessIterator2, _BinaryPredicate); -//------------------------------------------------------------------------ -// find_if -//------------------------------------------------------------------------ - -template -_ForwardIterator __brick_find_if( - _ForwardIterator, - _ForwardIterator, - _Predicate, - /*is_vector=*/std::false_type) noexcept; - -template -_RandomAccessIterator __brick_find_if( - _RandomAccessIterator, - _RandomAccessIterator, - _Predicate, - /*is_vector=*/std::true_type) noexcept; - -template -_ForwardIterator __pattern_find_if(_Tag, _ExecutionPolicy&&, _ForwardIterator, _ForwardIterator, _Predicate) noexcept; - -template -_RandomAccessIterator __pattern_find_if( - __parallel_tag<_IsVector>, _ExecutionPolicy&&, _RandomAccessIterator, _RandomAccessIterator, _Predicate); - //------------------------------------------------------------------------ // find_end //------------------------------------------------------------------------ diff --git a/libcxx/include/__pstl/internal/algorithm_impl.h b/libcxx/include/__pstl/internal/algorithm_impl.h --- a/libcxx/include/__pstl/internal/algorithm_impl.h +++ b/libcxx/include/__pstl/internal/algorithm_impl.h @@ -579,60 +579,6 @@ }); } -//------------------------------------------------------------------------ -// find_if -//------------------------------------------------------------------------ -template -_ForwardIterator -__brick_find_if(_ForwardIterator __first, - _ForwardIterator __last, - _Predicate __pred, - /*is_vector=*/std::false_type) noexcept { - return std::find_if(__first, __last, __pred); -} - -template -_RandomAccessIterator -__brick_find_if(_RandomAccessIterator __first, - 
_RandomAccessIterator __last, - _Predicate __pred, - /*is_vector=*/std::true_type) noexcept { - typedef typename std::iterator_traits<_RandomAccessIterator>::difference_type _SizeType; - return __unseq_backend::__simd_first( - __first, _SizeType(0), __last - __first, [&__pred](_RandomAccessIterator __it, _SizeType __i) { - return __pred(__it[__i]); - }); -} - -template -_ForwardIterator __pattern_find_if( - _Tag, _ExecutionPolicy&&, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) noexcept { - return __internal::__brick_find_if(__first, __last, __pred, typename _Tag::__is_vector{}); -} - -template -_RandomAccessIterator __pattern_find_if( - __parallel_tag<_IsVector> __tag, - _ExecutionPolicy&& __exec, - _RandomAccessIterator __first, - _RandomAccessIterator __last, - _Predicate __pred) { - using __backend_tag = typename decltype(__tag)::__backend_tag; - - return __internal::__except_handler([&]() { - return __internal::__parallel_find( - __backend_tag{}, - std::forward<_ExecutionPolicy>(__exec), - __first, - __last, - [__pred](_RandomAccessIterator __i, _RandomAccessIterator __j) { - return __internal::__brick_find_if(__i, __j, __pred, _IsVector{}); - }, - std::less::difference_type>(), - /*is_first=*/true); - }); -} - //------------------------------------------------------------------------ // find_end //------------------------------------------------------------------------ diff --git a/libcxx/include/__pstl/internal/glue_algorithm_defs.h b/libcxx/include/__pstl/internal/glue_algorithm_defs.h --- a/libcxx/include/__pstl/internal/glue_algorithm_defs.h +++ b/libcxx/include/__pstl/internal/glue_algorithm_defs.h @@ -20,20 +20,6 @@ namespace std { -// [alg.find] - -template -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator> -find_if(_ExecutionPolicy&& __exec, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred); - -template -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy,
_ForwardIterator> -find_if_not(_ExecutionPolicy&& __exec, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred); - -template -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator> -find(_ExecutionPolicy&& __exec, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value); - // [alg.find.end] template diff --git a/libcxx/include/__pstl/internal/glue_algorithm_impl.h b/libcxx/include/__pstl/internal/glue_algorithm_impl.h --- a/libcxx/include/__pstl/internal/glue_algorithm_impl.h +++ b/libcxx/include/__pstl/internal/glue_algorithm_impl.h @@ -25,30 +25,6 @@ namespace std { -// [alg.find] - -template -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator> -find_if(_ExecutionPolicy&& __exec, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) { - auto __dispatch_tag = __pstl::__internal::__select_backend(__exec, __first); - - return __pstl::__internal::__pattern_find_if( - __dispatch_tag, std::forward<_ExecutionPolicy>(__exec), __first, __last, __pred); -} - -template -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator> -find_if_not(_ExecutionPolicy&& __exec, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) { - return std::find_if(std::forward<_ExecutionPolicy>(__exec), __first, __last, std::not_fn(__pred)); -} - -template -__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator> -find(_ExecutionPolicy&& __exec, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value) { - return std::find_if( - std::forward<_ExecutionPolicy>(__exec), __first, __last, __pstl::__internal::__equal_value<_Tp>(__value)); -} - // [alg.find.end] template __pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator1> diff --git a/libcxx/include/algorithm b/libcxx/include/algorithm --- a/libcxx/include/algorithm +++ b/libcxx/include/algorithm @@ -1910,6 +1910,7 @@ #ifdef
_LIBCPP_HAS_PARALLEL_ALGORITHMS # include <__algorithm/pstl_any_all_none_of.h> +# include <__algorithm/pstl_find.h> # include <__algorithm/pstl_for_each.h> #endif diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find.pass.cpp new file mode 100644 --- /dev/null +++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find.pass.cpp @@ -0,0 +1,82 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14 + +// REQUIRES: with-pstl + +// <algorithm> + +// template<class ExecutionPolicy, class ForwardIterator, class T> +// ForwardIterator find(ExecutionPolicy&& exec, ForwardIterator first, ForwardIterator last, +// const T& value); + +#include +#include +#include + +#include "test_macros.h" +#include "test_execution_policies.h" +#include "test_iterators.h" + +EXECUTION_POLICY_SFINAE_TEST(find); + +static_assert(sfinae_test_find); +static_assert(!sfinae_test_find); + +template +struct Test { + template + void operator()(Policy&& policy) { + int a[] = {1, 2, 3, 4, 5, 6, 7, 8}; + + // simple test + assert(base(std::find(policy, Iter(std::begin(a)), Iter(std::end(a)), 3)) == a + 2); + + // check that last is returned if no element matches + assert(base(std::find(policy, Iter(std::begin(a)), Iter(std::end(a)), 0)) == std::end(a)); + + // check that the first element is returned + assert(base(std::find(policy, Iter(std::begin(a)), Iter(std::end(a)), 1)) == std::begin(a)); + + // check that an empty range works + assert(base(std::find(policy, Iter(std::begin(a)), Iter(std::begin(a)), 1)) == std::begin(a)); + + // check that a one-element range works + assert(base(std::find(policy,
Iter(std::begin(a)), Iter(std::begin(a) + 1), 1)) == std::begin(a)); + + // check that a two-element range works + assert(base(std::find(policy, Iter(std::begin(a)), Iter(std::begin(a) + 2), 2)) == std::begin(a) + 1); + + // check that a large number of elements works + std::vector vec(200, 4); + vec[176] = 5; + assert(base(std::find(policy, Iter(std::data(vec)), Iter(std::data(vec) + std::size(vec)), 5)) == + std::data(vec) + 176); + } +}; + +struct ThrowOnCompare {}; // its operator== (below) always throws, to exercise exception handling under std::execution::par + +bool operator==(ThrowOnCompare, ThrowOnCompare) { throw int{}; } + +int main(int, char**) { + types::for_each(types::forward_iterator_list{}, TestIteratorWithPolicies{}); + +#ifndef TEST_HAS_NO_EXCEPTIONS + std::set_terminate(terminate_successful); + ThrowOnCompare a[2]; + try { + (void)std::find(std::execution::par, std::begin(a), std::end(a), ThrowOnCompare{}); + } catch (int) { + assert(false); + } +#endif + + return 0; +} diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find_if.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find_if.pass.cpp new file mode 100644 --- /dev/null +++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find_if.pass.cpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14 + +// REQUIRES: with-pstl + +// <algorithm> + +// template<class ExecutionPolicy, class ForwardIterator, class Predicate> +// ForwardIterator find_if(ExecutionPolicy&& exec, ForwardIterator first, ForwardIterator last, +// Predicate pred); + +#include +#include +#include + +#include "test_macros.h" +#include "test_execution_policies.h" +#include "test_iterators.h" + +EXECUTION_POLICY_SFINAE_TEST(find_if); + +static_assert(sfinae_test_find_if); +static_assert(!sfinae_test_find_if); + +template +struct Test { + template + void operator()(Policy&& policy) { + int a[] = {1, 2, 3, 4, 5, 6, 7, 8}; + + // simple test + assert(base(std::find_if(policy, Iter(std::begin(a)), Iter(std::end(a)), [](int i) { return i == 3; })) == a + 2); + + // check that last is returned if no element satisfies the predicate + assert(base(std::find_if(policy, Iter(std::begin(a)), Iter(std::end(a)), [](int i) { return i == 0; })) == + std::end(a)); + + // check that the first element is returned + assert(base(std::find_if(policy, Iter(std::begin(a)), Iter(std::end(a)), [](int i) { return i == 1; })) == + std::begin(a)); + + // check that an empty range works + assert(base(std::find_if(policy, Iter(std::begin(a)), Iter(std::begin(a)), [](int i) { return i == 1; })) == + std::begin(a)); + + // check that a one-element range works + assert(base(std::find_if(policy, Iter(std::begin(a)), Iter(std::begin(a) + 1), [](int i) { return i == 1; })) == + std::begin(a)); + + // check that a two-element range works + assert(base(std::find_if(policy, Iter(std::begin(a)), Iter(std::begin(a) + 2), [](int i) { return i == 2; })) == + std::begin(a) + 1); + + // check that a large number of elements works + std::vector vec(200, 4); + vec[176] = 5; + assert(base(std::find_if(policy, Iter(std::data(vec)), Iter(std::data(vec) + std::size(vec)), [](int i) { + return i == 5; + })) == std::data(vec) + 176); + } +}; + +int
main(int, char**) { + types::for_each(types::forward_iterator_list{}, TestIteratorWithPolicies{}); + +#ifndef TEST_HAS_NO_EXCEPTIONS + std::set_terminate(terminate_successful); + int a[] = {1, 2}; + try { + (void)std::find_if(std::execution::par, std::begin(a), std::end(a), [](int) -> bool { throw int{}; }); + } catch (int) { + assert(false); + } +#endif + + return 0; +} diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find_if_not.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find_if_not.pass.cpp new file mode 100644 --- /dev/null +++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/pstl.find_if_not.pass.cpp @@ -0,0 +1,85 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14 + +// REQUIRES: with-pstl + +// <algorithm> + +// template<class ExecutionPolicy, class ForwardIterator, class Predicate> +// ForwardIterator find_if_not(ExecutionPolicy&& exec, ForwardIterator first, ForwardIterator last, +// Predicate pred); + +#include +#include +#include + +#include "test_macros.h" +#include "test_execution_policies.h" +#include "test_iterators.h" + +EXECUTION_POLICY_SFINAE_TEST(find_if_not); + +static_assert(sfinae_test_find_if_not); +static_assert(!sfinae_test_find_if_not); + +template +struct Test { + template + void operator()(Policy&& policy) { + int a[] = {1, 2, 3, 4, 5, 6, 7, 8}; + + // simple test + assert(base(std::find_if_not(policy, Iter(std::begin(a)), Iter(std::end(a)), [](int i) { return i != 3; })) == + a + 2); + + // check that last is returned if every element satisfies the predicate + assert(base(std::find_if_not(policy, Iter(std::begin(a)), Iter(std::end(a)), [](int i) { return i != 0; })) == + std::end(a)); + + // check that the first
element is returned + assert(base(std::find_if_not(policy, Iter(std::begin(a)), Iter(std::end(a)), [](int i) { return i != 1; })) == + std::begin(a)); + + // check that an empty range works + assert(base(std::find_if_not(policy, Iter(std::begin(a)), Iter(std::begin(a)), [](int i) { return i != 1; })) == + std::begin(a)); + + // check that a one-element range works + assert(base(std::find_if_not(policy, Iter(std::begin(a)), Iter(std::begin(a) + 1), [](int i) { return i != 1; })) == + std::begin(a)); + + // check that a two-element range works + assert(base(std::find_if_not(policy, Iter(std::begin(a)), Iter(std::begin(a) + 2), [](int i) { return i != 2; })) == + std::begin(a) + 1); + + // check that a large number of elements works + std::vector vec(200, 4); + vec[176] = 5; + assert(base(std::find_if_not(policy, Iter(std::data(vec)), Iter(std::data(vec) + std::size(vec)), [](int i) { + return i != 5; + })) == std::data(vec) + 176); + } +}; + +int main(int, char**) { + types::for_each(types::forward_iterator_list{}, TestIteratorWithPolicies{}); + +#ifndef TEST_HAS_NO_EXCEPTIONS + std::set_terminate(terminate_successful); + int a[] = {1, 2}; + try { + (void)std::find_if_not(std::execution::par, std::begin(a), std::end(a), [](int) -> bool { throw int{}; }); + } catch (int) { + assert(false); + } +#endif + + return 0; +}