NOTE(review): this copy of the patch has lost its line breaks and the text
inside angle brackets (template parameter lists, #include targets, template
arguments such as iterator_traits<...>), so it will not apply as-is; restore
it from the original commit before applying.

Summary of the header changes in this patch:
* binary_search.h: include order only (comp_ref_type.h now precedes
  lower_bound.h); no functional change.
* is_heap_until.h, is_sorted_until.h, max_element.h, min_element.h: the old
  three-argument algorithm body is renamed to an internal helper
  (__is_heap_until, __is_sorted_until, __max_element, __min_element) marked
  _LIBCPP_HIDE_FROM_ABI and no longer _LIBCPP_NODISCARD_EXT. A new public
  three-argument overload computes `typename __comp_ref_type<_Compare>::type`
  as _Comp_ref and calls the helper with _Comp_ref as the explicit template
  argument, passing its own `__comp` parameter.
* The two-argument overloads in those four headers switch from
  _LIBCPP_INLINE_VISIBILITY to _LIBCPP_HIDE_FROM_ABI. is_heap_until's
  two-argument overload calls __is_heap_until directly with __less<...>; the
  other three still call their public three-argument overload.
* is_heap.h, is_sorted.h, and the initializer_list overloads in max.h and
  min.h now call the internal helpers through __comp_ref_type instead of the
  public algorithms.
Presumably __comp_ref_type<_Compare>::type is a reference to _Compare (or a
debug wrapper holding one), so the comparator is not copied on the way into
the helper. TODO: confirm in __algorithm/comp_ref_type.h, which is not part
of this diff.
The test file added at the end of the patch
(robust_against_copying_comparators.pass.cpp) asserts that these and other
algorithms make zero copies of the user's callable (one copy for for_each).

diff --git a/libcxx/include/__algorithm/binary_search.h b/libcxx/include/__algorithm/binary_search.h --- a/libcxx/include/__algorithm/binary_search.h +++ b/libcxx/include/__algorithm/binary_search.h @@ -11,8 +11,8 @@ #include <__config> #include <__algorithm/comp.h> -#include <__algorithm/lower_bound.h> #include <__algorithm/comp_ref_type.h> +#include <__algorithm/lower_bound.h> #include <__iterator/iterator_traits.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) diff --git a/libcxx/include/__algorithm/is_heap.h b/libcxx/include/__algorithm/is_heap.h --- a/libcxx/include/__algorithm/is_heap.h +++ b/libcxx/include/__algorithm/is_heap.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__algorithm/is_heap_until.h> #include <__iterator/iterator_traits.h> @@ -26,7 +27,8 @@ bool is_heap(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare __comp) { - return _VSTD::is_heap_until(__first, __last, __comp) == __last; + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return _VSTD::__is_heap_until<_Comp_ref>(__first, __last, __comp) == __last; } template diff --git a/libcxx/include/__algorithm/is_heap_until.h b/libcxx/include/__algorithm/is_heap_until.h --- a/libcxx/include/__algorithm/is_heap_until.h +++ b/libcxx/include/__algorithm/is_heap_until.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__iterator/iterator_traits.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -19,9 +20,9 @@ _LIBCPP_BEGIN_NAMESPACE_STD -template -_LIBCPP_NODISCARD_EXT _LIBCPP_CONSTEXPR_AFTER_CXX17 _RandomAccessIterator -is_heap_until(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare __comp) +template +_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 _RandomAccessIterator +__is_heap_until(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare __comp) { typedef typename 
iterator_traits<_RandomAccessIterator>::difference_type difference_type; difference_type __len = __last - __first; @@ -46,13 +47,19 @@ return __last; } +template +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 _RandomAccessIterator +is_heap_until(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare __comp) +{ + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return _VSTD::__is_heap_until<_Comp_ref>(__first, __last, __comp); +} + template -_LIBCPP_NODISCARD_EXT inline -_LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17 -_RandomAccessIterator +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 _RandomAccessIterator is_heap_until(_RandomAccessIterator __first, _RandomAccessIterator __last) { - return _VSTD::is_heap_until(__first, __last, __less::value_type>()); + return _VSTD::__is_heap_until(__first, __last, __less::value_type>()); } _LIBCPP_END_NAMESPACE_STD diff --git a/libcxx/include/__algorithm/is_sorted.h b/libcxx/include/__algorithm/is_sorted.h --- a/libcxx/include/__algorithm/is_sorted.h +++ b/libcxx/include/__algorithm/is_sorted.h @@ -10,6 +10,7 @@ #define _LIBCPP___ALGORITHM_IS_SORTED_H #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__algorithm/is_sorted_until.h> #include <__config> #include <__iterator/iterator_traits.h> @@ -26,7 +27,8 @@ bool is_sorted(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) { - return _VSTD::is_sorted_until(__first, __last, __comp) == __last; + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return _VSTD::__is_sorted_until<_Comp_ref>(__first, __last, __comp) == __last; } template diff --git a/libcxx/include/__algorithm/is_sorted_until.h b/libcxx/include/__algorithm/is_sorted_until.h --- a/libcxx/include/__algorithm/is_sorted_until.h +++ b/libcxx/include/__algorithm/is_sorted_until.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include 
<__algorithm/comp_ref_type.h> #include <__iterator/iterator_traits.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -19,9 +20,9 @@ _LIBCPP_BEGIN_NAMESPACE_STD -template -_LIBCPP_NODISCARD_EXT _LIBCPP_CONSTEXPR_AFTER_CXX17 _ForwardIterator -is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) +template +_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 _ForwardIterator +__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) { if (__first != __last) { @@ -36,10 +37,16 @@ return __last; } +template +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 _ForwardIterator +is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) +{ + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return _VSTD::__is_sorted_until<_Comp_ref>(__first, __last, __comp); +} + template -_LIBCPP_NODISCARD_EXT inline -_LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17 -_ForwardIterator +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 _ForwardIterator is_sorted_until(_ForwardIterator __first, _ForwardIterator __last) { return _VSTD::is_sorted_until(__first, __last, __less::value_type>()); diff --git a/libcxx/include/__algorithm/max.h b/libcxx/include/__algorithm/max.h --- a/libcxx/include/__algorithm/max.h +++ b/libcxx/include/__algorithm/max.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__algorithm/max_element.h> #include @@ -49,7 +50,8 @@ _Tp max(initializer_list<_Tp> __t, _Compare __comp) { - return *_VSTD::max_element(__t.begin(), __t.end(), __comp); + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return *_VSTD::__max_element<_Comp_ref>(__t.begin(), __t.end(), __comp); } template diff --git a/libcxx/include/__algorithm/max_element.h b/libcxx/include/__algorithm/max_element.h --- a/libcxx/include/__algorithm/max_element.h +++ 
b/libcxx/include/__algorithm/max_element.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__iterator/iterator_traits.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -19,11 +20,9 @@ _LIBCPP_BEGIN_NAMESPACE_STD -template -_LIBCPP_NODISCARD_EXT inline -_LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX11 -_ForwardIterator -max_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) +template +inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX11 _ForwardIterator +__max_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) { static_assert(__is_cpp17_forward_iterator<_ForwardIterator>::value, "std::max_element requires a ForwardIterator"); @@ -37,11 +36,17 @@ return __first; } +template +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX11 _ForwardIterator +max_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) +{ + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return _VSTD::__max_element<_Comp_ref>(__first, __last, __comp); +} + template -_LIBCPP_NODISCARD_EXT inline -_LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX11 -_ForwardIterator +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX11 _ForwardIterator max_element(_ForwardIterator __first, _ForwardIterator __last) { return _VSTD::max_element(__first, __last, diff --git a/libcxx/include/__algorithm/min.h b/libcxx/include/__algorithm/min.h --- a/libcxx/include/__algorithm/min.h +++ b/libcxx/include/__algorithm/min.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__algorithm/min_element.h> #include @@ -49,7 +50,8 @@ _Tp min(initializer_list<_Tp> __t, _Compare __comp) { - return *_VSTD::min_element(__t.begin(), __t.end(), __comp); + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return 
*_VSTD::__min_element<_Comp_ref>(__t.begin(), __t.end(), __comp); } template diff --git a/libcxx/include/__algorithm/min_element.h b/libcxx/include/__algorithm/min_element.h --- a/libcxx/include/__algorithm/min_element.h +++ b/libcxx/include/__algorithm/min_element.h @@ -11,6 +11,7 @@ #include <__config> #include <__algorithm/comp.h> +#include <__algorithm/comp_ref_type.h> #include <__iterator/iterator_traits.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -19,11 +20,9 @@ _LIBCPP_BEGIN_NAMESPACE_STD -template -_LIBCPP_NODISCARD_EXT inline -_LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX11 -_ForwardIterator -min_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) +template +inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX11 _ForwardIterator +__min_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) { static_assert(__is_cpp17_forward_iterator<_ForwardIterator>::value, "std::min_element requires a ForwardIterator"); @@ -37,10 +36,16 @@ return __first; } +template +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX11 _ForwardIterator +min_element(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) +{ + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; + return _VSTD::__min_element<_Comp_ref>(__first, __last, __comp); +} + template -_LIBCPP_NODISCARD_EXT inline -_LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX11 -_ForwardIterator +_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX11 _ForwardIterator min_element(_ForwardIterator __first, _ForwardIterator __last) { return _VSTD::min_element(__first, __last, diff --git a/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp b/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp new file mode 100644 --- /dev/null +++ b/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp @@ -0,0 +1,189 @@ 
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// <algorithm>
+
+#include <algorithm>
+#include <cassert>
+
+#include "test_macros.h"
+
+struct Less { // each callable below bumps *copies_ whenever it is copy-constructed
+    int *copies_;
+    TEST_CONSTEXPR explicit Less(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 Less(const Less& rhs) : copies_(rhs.copies_) { *copies_ += 1; } // C++14: C++11 constexpr ctors need an empty body
+    TEST_CONSTEXPR_CXX14 Less& operator=(const Less&) = default; // C++14: C++11 constexpr members are implicitly const
+    TEST_CONSTEXPR bool operator()(void*, void*) const { return false; }
+};
+
+struct Equal {
+    int *copies_;
+    TEST_CONSTEXPR explicit Equal(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 Equal(const Equal& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
+    TEST_CONSTEXPR_CXX14 Equal& operator=(const Equal&) = default;
+    TEST_CONSTEXPR bool operator()(void*, void*) const { return true; }
+};
+
+struct UnaryVoid {
+    int *copies_;
+    TEST_CONSTEXPR explicit UnaryVoid(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 UnaryVoid(const UnaryVoid& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
+    TEST_CONSTEXPR_CXX14 UnaryVoid& operator=(const UnaryVoid&) = default;
+    TEST_CONSTEXPR_CXX14 void operator()(void*) const {}
+};
+
+struct UnaryTrue {
+    int *copies_;
+    TEST_CONSTEXPR explicit UnaryTrue(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 UnaryTrue(const UnaryTrue& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
+    TEST_CONSTEXPR_CXX14 UnaryTrue& operator=(const UnaryTrue&) = default;
+    TEST_CONSTEXPR bool operator()(void*) const { return true; }
+};
+
+struct NullaryValue {
+    int *copies_;
+    TEST_CONSTEXPR explicit NullaryValue(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 NullaryValue(const NullaryValue& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
+    TEST_CONSTEXPR_CXX14 NullaryValue& operator=(const NullaryValue&) = default;
+    TEST_CONSTEXPR decltype(nullptr) operator()() const { return nullptr; }
+};
+
+struct UnaryTransform {
+    int *copies_;
+    TEST_CONSTEXPR explicit UnaryTransform(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 UnaryTransform(const UnaryTransform& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
+    TEST_CONSTEXPR_CXX14 UnaryTransform& operator=(const UnaryTransform&) = default;
+    TEST_CONSTEXPR decltype(nullptr) operator()(void*) const { return nullptr; }
+};
+
+struct BinaryTransform {
+    int *copies_;
+    TEST_CONSTEXPR explicit BinaryTransform(int *copies) : copies_(copies) {}
+    TEST_CONSTEXPR_CXX14 BinaryTransform(const BinaryTransform& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
+    TEST_CONSTEXPR_CXX14 BinaryTransform& operator=(const BinaryTransform&) = default;
+    TEST_CONSTEXPR decltype(nullptr) operator()(void*, void*) const { return nullptr; }
+};
+
+TEST_CONSTEXPR_CXX20 bool all_the_algorithms()
+{
+    void *a[10] = {};
+    void *b[10] = {};
+    void **first = a;
+    void **mid = a+5;
+    void **last = a+10;
+    void **first2 = b;
+    void **mid2 = b+5;
+    void **last2 = b+10;
+    void *value = nullptr;
+    int count = 1;
+
+    int copies = 0; // incremented by every copy of a callable passed below
+    (void)std::adjacent_find(first, last, Equal(&copies)); assert(copies == 0);
+#if TEST_STD_VER >= 11
+    (void)std::all_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::any_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+#endif
+    (void)std::binary_search(first, last, value, Less(&copies)); assert(copies == 0);
+#if TEST_STD_VER > 14 // std::clamp is C++17
+    (void)std::clamp(value, value, value, Less(&copies)); assert(copies == 0);
+#endif
+    (void)std::count_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::equal(first, last, first2, Equal(&copies)); assert(copies == 0);
+#if TEST_STD_VER > 11
+    (void)std::equal(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+#endif
+    (void)std::equal_range(first, last, value, Less(&copies)); assert(copies == 0);
+    (void)std::find_end(first, last, first2, mid2, Equal(&copies)); assert(copies == 0);
+    (void)std::find_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::find_if_not(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::for_each(first, last, UnaryVoid(&copies)); assert(copies == 1); copies = 0; // for_each returns its functor by value: one copy expected
+#if TEST_STD_VER > 14
+    (void)std::for_each_n(first, count, UnaryVoid(&copies)); assert(copies == 0);
+#endif
+    (void)std::generate(first, last, NullaryValue(&copies)); assert(copies == 0);
+    (void)std::generate_n(first, count, NullaryValue(&copies)); assert(copies == 0);
+    (void)std::includes(first, last, first2, last2, Less(&copies)); assert(copies == 0);
+    (void)std::is_heap(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::is_heap_until(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::is_partitioned(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::is_permutation(first, last, first2, Equal(&copies)); assert(copies == 0);
+#if TEST_STD_VER > 11
+    (void)std::is_permutation(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+#endif
+    (void)std::is_sorted(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::is_sorted_until(first, last, Less(&copies)); assert(copies == 0);
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::inplace_merge(first, mid, last, Less(&copies)); assert(copies == 0); }
+    (void)std::lexicographical_compare(first, last, first2, last2, Less(&copies)); assert(copies == 0);
+    // TODO: lexicographical_compare_three_way
+    (void)std::lower_bound(first, last, value, Less(&copies)); assert(copies == 0);
+    (void)std::make_heap(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::max(value, value, Less(&copies)); assert(copies == 0);
+#if TEST_STD_VER >= 11
+    (void)std::max({ value, value }, Less(&copies)); assert(copies == 0);
+#endif
+    (void)std::max_element(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::merge(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
+    (void)std::min(value, value, Less(&copies)); assert(copies == 0);
+#if TEST_STD_VER >= 11
+    (void)std::min({ value, value }, Less(&copies)); assert(copies == 0);
+#endif
+    (void)std::min_element(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::minmax(value, value, Less(&copies)); assert(copies == 0);
+#if TEST_STD_VER >= 11
+    (void)std::minmax({ value, value }, Less(&copies)); assert(copies == 0);
+#endif
+    (void)std::minmax_element(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::mismatch(first, last, first2, Equal(&copies)); assert(copies == 0);
+#if TEST_STD_VER > 11
+    (void)std::mismatch(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+#endif
+    (void)std::next_permutation(first, last, Less(&copies)); assert(copies == 0);
+#if TEST_STD_VER >= 11
+    (void)std::none_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+#endif
+    (void)std::nth_element(first, mid, last, Less(&copies)); assert(copies == 0);
+    (void)std::partial_sort(first, mid, last, Less(&copies)); assert(copies == 0);
+    (void)std::partial_sort_copy(first, last, first2, mid2, Less(&copies)); assert(copies == 0);
+    (void)std::partition(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::partition_copy(first, last, first2, last2, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::partition_point(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::pop_heap(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::prev_permutation(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::push_heap(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::remove_copy_if(first, last, first2, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::remove_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::replace_copy_if(first, last, first2, UnaryTrue(&copies), value); assert(copies == 0);
+    (void)std::replace_if(first, last, UnaryTrue(&copies), value); assert(copies == 0);
+    (void)std::search(first, last, first2, mid2, Equal(&copies)); assert(copies == 0);
+    (void)std::search_n(first, last, count, value, Equal(&copies)); assert(copies == 0);
+    (void)std::set_difference(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
+    (void)std::set_intersection(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
+    (void)std::set_symmetric_difference(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
+    (void)std::set_union(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
+    (void)std::sort(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::sort_heap(first, last, Less(&copies)); assert(copies == 0);
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_partition(first, last, UnaryTrue(&copies)); assert(copies == 0); }
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_sort(first, last, Less(&copies)); assert(copies == 0); }
+    (void)std::transform(first, last, first2, UnaryTransform(&copies)); assert(copies == 0);
+    (void)std::transform(first, mid, mid, first2, BinaryTransform(&copies)); assert(copies == 0);
+    (void)std::unique(first, last, Equal(&copies)); assert(copies == 0);
+    (void)std::unique_copy(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::upper_bound(first, last, value, Less(&copies)); assert(copies == 0);
+
+    return true;
+}
+
+int main(int, char**)
+{
+    all_the_algorithms();
+#if TEST_STD_VER > 17
+    static_assert(all_the_algorithms());
+#endif
+
+    return 0;
+}