diff --git a/libcxx/include/__algorithm/includes.h b/libcxx/include/__algorithm/includes.h --- a/libcxx/include/__algorithm/includes.h +++ b/libcxx/include/__algorithm/includes.h @@ -12,7 +12,10 @@ #include <__algorithm/comp.h> #include <__algorithm/comp_ref_type.h> #include <__config> +#include <__functional/identity.h> +#include <__functional/invoke.h> #include <__iterator/iterator_traits.h> +#include <__type_traits/is_callable.h> #include <__utility/move.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -21,13 +24,15 @@ _LIBCPP_BEGIN_NAMESPACE_STD -template +template _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_AFTER_CXX17 bool -__includes(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Sent2 __last2, _Comp&& __comp) { +__includes(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Sent2 __last2, + _Comp&& __comp, _Proj1&& __proj1, _Proj2&& __proj2) { for (; __first2 != __last2; ++__first1) { - if (__first1 == __last1 || __comp(*__first2, *__first1)) + if (__first1 == __last1 || std::__invoke( + __comp, std::__invoke(__proj2, *__first2), std::__invoke(__proj1, *__first1))) return false; - if (!__comp(*__first1, *__first2)) + if (!std::__invoke(__comp, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2))) ++__first2; } return true; @@ -40,9 +45,13 @@ _InputIterator2 __first2, _InputIterator2 __last2, _Compare __comp) { + static_assert(__is_callable<_Compare, decltype(*__first1), decltype(*__first2)>::value, + "Comparator has to be callable"); + typedef typename __comp_ref_type<_Compare>::type _Comp_ref; return std::__includes( - std::move(__first1), std::move(__last1), std::move(__first2), std::move(__last2), static_cast<_Comp_ref>(__comp)); + std::move(__first1), std::move(__last1), std::move(__first2), std::move(__last2), + static_cast<_Comp_ref>(__comp), __identity(), __identity()); } template diff --git a/libcxx/include/__algorithm/ranges_includes.h b/libcxx/include/__algorithm/ranges_includes.h --- a/libcxx/include/__algorithm/ranges_includes.h +++ b/libcxx/include/__algorithm/ranges_includes.h @@ -56,7 +56,9 @@ std::move(__last1), std::move(__first2), std::move(__last2), - ranges::__make_projected_comp(__comp, __proj1, __proj2)); + std::move(__comp), + std::move(__proj1), + std::move(__proj2)); } template < @@ -73,7 +75,9 @@ ranges::end(__range1), ranges::begin(__range2), ranges::end(__range2), - ranges::__make_projected_comp(__comp, __proj1, __proj2)); + std::move(__comp), + std::move(__proj1), + std::move(__proj2)); } }; diff --git a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp new file mode 100644 --- /dev/null +++ b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17 +// UNSUPPORTED: libcpp-has-no-incomplete-ranges + +// +// +// Range algorithms should support the case where the ranges they operate on have different value types and the given +// projection functors are different (each projection applies to a different value type). + +#include + +#include +#include +#include +#include + +// (in1, in2, ...) +template +constexpr void test(Func&& func, Input1& in1, Input2& in2, Args&& ...args) { + func(in1.begin(), in1.end(), in2.begin(), in2.end(), std::forward(args)...); + func(in1, in2, std::forward(args)...); +} + +constexpr bool test_all() { + struct A { + int x = 0; + + constexpr A() = default; + constexpr A(int value) : x(value) {} + constexpr operator int() const { return x; } + + constexpr auto operator<=>(const A&) const = default; + }; + + std::array in = {1, 2, 3}; + std::array in2 = {A{4}, A{5}, A{6}}; + + std::array output = {7, 8, 9, 10, 11, 12}; + auto out = output.begin(); + std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}}; + auto out2 = output2.begin(); + + std::ranges::equal_to eq; + std::ranges::less less; + auto sum = [](int lhs, A rhs) { return lhs + rhs.x; }; + auto proj1 = [](int x) { return x * -1; }; + auto proj2 = [](A a) { return a.x * -1; }; + + test(std::ranges::equal, in, in2, eq, proj1, proj2); + test(std::ranges::lexicographical_compare, in, in2, eq, proj1, proj2); + //test(std::ranges::is_permutation, in, in2, eq, proj1, proj2); + test(std::ranges::includes, in, in2, less, proj1, proj2); + test(std::ranges::find_first_of, in, in2, eq, proj1, proj2); + test(std::ranges::mismatch, in, in2, eq, proj1, proj2); + test(std::ranges::search, in, in2, eq, proj1, proj2); + test(std::ranges::find_end, in, in2, eq, proj1, proj2); + test(std::ranges::transform, in, in2, out, sum, proj1, proj2); + test(std::ranges::transform, in, in2, out2, sum, proj1, proj2); + //test(std::ranges::partial_sort_copy, in, in2, output2.begin(), output2.end(), less, proj1, proj2); + test(std::ranges::merge, in, in2, out, less, proj1, proj2); + test(std::ranges::merge, in, in2, out2, less, proj1, proj2); + test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2); + test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2); + test(std::ranges::set_difference, in, in2, out, less, proj1, proj2); + test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2); + test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2); + test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2); + test(std::ranges::set_union, in, in2, out, less, proj1, proj2); + test(std::ranges::set_union, in, in2, out2, less, proj1, proj2); + + return true; +} + +int main(int, char**) { + test_all(); + static_assert(test_all()); + + return 0; +}