[libc++][PSTL] Add customization points for std::copy and std::copy_n

The execution-policy overloads of std::copy and std::copy_n are routed
through std::__pstl_frontend_dispatch using the new __pstl_copy and
__pstl_copy_n customization points. Both overloads gain a defaulted
_RawPolicy template parameter (the cvref-stripped policy type) that is
used in the enable_if constraint.

The lambdas passed to __pstl_frontend_dispatch reproduce the previous
implementations: std::transform with __identity for copy; for copy_n,
the parallel std::copy on random-access iterators and the serial
std::copy_n otherwise. They presumably act as the fallback when a
backend does not customize the operation (confirm against
pstl_frontend_dispatch.h).

The test mocks for __pstl_copy and __pstl_copy_n return the output
iterator they receive, typed as ForwardOutIterator. This matches the
return type of the algorithms they stand in for, rather than relying on
both iterator types being int* in this test and on 0 converting to a
null pointer.

diff --git a/libcxx/include/__algorithm/pstl_copy.h b/libcxx/include/__algorithm/pstl_copy.h
--- a/libcxx/include/__algorithm/pstl_copy.h
+++ b/libcxx/include/__algorithm/pstl_copy.h
@@ -10,6 +10,7 @@
 #define _LIBCPP___ALGORITHM_PSTL_COPY_H
 
 #include <__algorithm/copy_n.h>
+#include <__algorithm/pstl_frontend_dispatch.h>
 #include <__algorithm/pstl_transform.h>
 #include <__config>
 #include <__functional/identity.h>
@@ -30,26 +31,48 @@
 
 // TODO: Use the std::copy/move shenanigans to forward to std::memmove
 
+template <class>
+void __pstl_copy();
+
 template <class _ExecutionPolicy,
           class _ForwardIterator,
           class _ForwardOutIterator,
-          enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
+          class _RawPolicy                                    = __remove_cvref_t<_ExecutionPolicy>,
+          enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
 _LIBCPP_HIDE_FROM_ABI _ForwardOutIterator
 copy(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _ForwardOutIterator __result) {
-  return std::transform(__policy, __first, __last, __result, __identity());
+  return std::__pstl_frontend_dispatch(
+      _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_copy),
+      [&__policy](_ForwardIterator __g_first, _ForwardIterator __g_last, _ForwardOutIterator __g_result) {
+        return std::transform(__policy, __g_first, __g_last, __g_result, __identity());
+      },
+      std::move(__first),
+      std::move(__last),
+      std::move(__result));
 }
 
+template <class>
+void __pstl_copy_n();
+
 template <class _ExecutionPolicy,
           class _ForwardIterator,
           class _ForwardOutIterator,
           class _Size,
-          enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
+          class _RawPolicy                                    = __remove_cvref_t<_ExecutionPolicy>,
+          enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
 _LIBCPP_HIDE_FROM_ABI _ForwardOutIterator
 copy_n(_ExecutionPolicy&& __policy, _ForwardIterator __first, _Size __n, _ForwardOutIterator __result) {
-  if constexpr (__has_random_access_iterator_category<_ForwardIterator>::value)
-    return std::copy(__policy, __first, __first + __n, __result);
-  else
-    return std::copy_n(__first, __n, __result);
+  return std::__pstl_frontend_dispatch(
+      _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_copy_n),
+      [&__policy](_ForwardIterator __g_first, _Size __g_n, _ForwardOutIterator __g_result) {
+        if constexpr (__has_random_access_iterator_category<_ForwardIterator>::value)
+          return std::copy(__policy, __g_first, __g_first + __g_n, __g_result);
+        else
+          return std::copy_n(__g_first, __g_n, __g_result);
+      },
+      std::move(__first),
+      __n,
+      std::move(__result));
 }
 
 _LIBCPP_END_NAMESPACE_STD
diff --git a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
--- a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
@@ -42,6 +42,24 @@
   return true;
 }
 
+bool pstl_copy_called = false;
+
+template <class, class ForwardIterator, class ForwardOutIterator>
+ForwardOutIterator __pstl_copy(TestBackend, ForwardIterator, ForwardIterator, ForwardOutIterator res) {
+  assert(!pstl_copy_called);
+  pstl_copy_called = true;
+  return res;
+}
+
+bool pstl_copy_n_called = false;
+
+template <class, class ForwardIterator, class Size, class ForwardOutIterator>
+ForwardOutIterator __pstl_copy_n(TestBackend, ForwardIterator, Size, ForwardOutIterator res) {
+  assert(!pstl_copy_n_called);
+  pstl_copy_n_called = true;
+  return res;
+}
+
 bool pstl_count_called = false;
 
 template <class, class ForwardIterator, class T>
@@ -274,6 +292,10 @@
   assert(std::pstl_all_of_called);
   (void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
   assert(std::pstl_none_of_called);
+  (void)std::copy(TestPolicy{}, std::begin(a), std::end(a), std::begin(a));
+  assert(std::pstl_copy_called);
+  (void)std::copy_n(TestPolicy{}, std::begin(a), 1, std::begin(a));
+  assert(std::pstl_copy_n_called);
   (void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
   assert(std::pstl_count_called);
   (void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);