Review notes (free-form text placed before the first "Index:" header; patch(1) and git apply skip leading text like this).

Purpose: give definitions to the Parallelism TS v2 where-expression reductions in <experimental/simd> (reduce, hmin, hmax taking a const_where_expression). Before this patch they were only declared, as the removed "-...;" declarations below show. The patch also adds masked tests to hmax.pass.cpp and hmin.pass.cpp.

Header changes (libcxx/include/experimental/simd):
* The bodies of the simd overloads of reduce/hmin/hmax move into internal helpers __reduce/__hmin/__hmax. These are templated on any _SimdType that provides value_type, a const operator[] and size(). The public simd overloads now just forward to them.
* const_where_expression declares the where-expression reduce/hmin/hmax as friends, so they can read its __v_ (value vector) and __m_ (mask) members directly.
* reduce(where(m, v), identity, op) copies v, overwrites the lanes where m is false with identity, then folds every lane with op.
* The single-operator reduce overloads forward with the identity for their operator: 0 for plus, bit_or and bit_xor; 1 for multiplies; value_type(-1) for bit_and.
* hmin/hmax call __simd_select with a vector filled with numeric_limits<value_type>::max() (for hmin) or ::lowest() (for hmax), then reduce.
* NOTE(review): __simd_select is not shown here. It presumably takes __v_ where __m_ is set and the limit elsewhere. The added tests assert that an all-false mask yields that limit.

NOTE(review): this copy of the diff is damaged and cannot be applied as-is.
* Line breaks have been collapsed.
* Every span that looked like an HTML tag was stripped. This includes template parameter lists ("template <...>"), the header names after "#include", and arguments such as "<int>", "<float>" and "<>".
* Fetch the original patch from the review tool instead of hand-repairing this text.

NOTE(review): the added "#include" is presumably <limits>, which the new std::numeric_limits uses require. The name is stripped here, so confirm.

NOTE(review): the hmax/hmin test hunks use std::numeric_limits but add no "#include <limits>" of their own. They appear to rely on the simd header pulling it in transitively; consider including it explicitly in both tests.

NOTE(review): the int hmax tests expect numeric_limits::min(), while the implementation and the float hmax test use lowest(). The two are equal for int, but lowest() would state the intended contract consistently.

NOTE(review): hmin gets no floating-point masked test, unlike hmax. A case expecting numeric_limits<float>::max() for an all-false mask would mirror the hmax coverage.
Index: libcxx/include/experimental/simd =================================================================== --- libcxx/include/experimental/simd +++ libcxx/include/experimental/simd @@ -594,6 +594,7 @@ #include #include #include +#include #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) #pragma GCC system_header @@ -1634,33 +1635,48 @@ } // reductions [simd.reductions] -template > -_Tp reduce(const simd<_Tp, _Abi>& __v, _BinaryOp __op = _BinaryOp()) { - _Tp __acc = __v[0]; +template +typename _SimdType::value_type __reduce(const _SimdType& __v, _BinaryOp __op) { + auto __acc = __v[0]; for (size_t __i = 1; __i < __v.size(); __i++) { __acc = __op(__acc, __v[__i]); } return __acc; } -template -_Tp hmin(const simd<_Tp, _Abi>& __v) { - _Tp __acc = __v[0]; +template +typename _SimdType::value_type __hmin(const _SimdType& __v) { + auto __acc = __v[0]; for (size_t __i = 1; __i < __v.size(); __i++) { __acc = __acc > __v[__i] ? __v[__i] : __acc; } return __acc; } -template -_Tp hmax(const simd<_Tp, _Abi>& __v) { - _Tp __acc = __v[0]; +template +typename _SimdType::value_type __hmax(const _SimdType& __v) { + auto __acc = __v[0]; for (size_t __i = 1; __i < __v.size(); __i++) { __acc = __acc < __v[__i] ? 
__v[__i] : __acc; } return __acc; } +template > +_Tp reduce(const simd<_Tp, _Abi>& __v, _BinaryOp __op = _BinaryOp()) { + return __reduce(__v, __op); +} + +template +_Tp hmin(const simd<_Tp, _Abi>& __v) { + return __hmin(__v); +} + +template +_Tp hmax(const simd<_Tp, _Abi>& __v) { + return __hmax(__v); +} + // algorithms [simd.alg] template simd<_Tp, _Abi> min(const simd<_Tp, _Abi>& __a, @@ -2295,6 +2311,19 @@ const_where_expression>::type where(_Mp __m, const _Up& __v) noexcept; + template + friend typename _Vp::value_type + reduce(const const_where_expression<_Mp, _Vp>& __w, + typename _Vp::value_type __identity, _BinaryOp __op); + + template + friend typename _Vp::value_type + hmin(const const_where_expression<_Mp, _Vp>& __w); + + template + friend typename _Vp::value_type + hmax(const const_where_expression<_Mp, _Vp>& __w); + public: const_where_expression& operator=(const const_where_expression&) = delete; @@ -2479,41 +2508,63 @@ template typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - typename _SimdType::value_type neutral_element, _BinaryOp binary_op); +reduce(const const_where_expression<_MaskType, _SimdType>& __w, + typename _SimdType::value_type __identity, _BinaryOp __op) { + auto __v = __w.__v_; + where(!__w.__m_, __v) = _SimdType(__identity); + return __reduce(__v, __op); +} template typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - plus binary_op = {}); +reduce(const const_where_expression<_MaskType, _SimdType>& __w, + plus __op = {}) { + return reduce(__w, typename _SimdType::value_type(0), __op); +} template typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - multiplies binary_op); +reduce(const const_where_expression<_MaskType, _SimdType>& __w, + multiplies __op) { + return reduce(__w, typename _SimdType::value_type(1), __op); +} template typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, 
_SimdType>&, - bit_and binary_op); +reduce(const const_where_expression<_MaskType, _SimdType>& __w, + bit_and __op) { + return reduce(__w, typename _SimdType::value_type(-1), __op); +} template typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - bit_or binary_op); +reduce(const const_where_expression<_MaskType, _SimdType>& __w, + bit_or __op) { + return reduce(__w, typename _SimdType::value_type(0), __op); +} template typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - bit_xor binary_op); +reduce(const const_where_expression<_MaskType, _SimdType>& __w, + bit_xor __op) { + return reduce(__w, typename _SimdType::value_type(0), __op); +} template typename _SimdType::value_type -hmin(const const_where_expression<_MaskType, _SimdType>&); +hmin(const const_where_expression<_MaskType, _SimdType>& __w) { + return __hmin(__simd_mask_friend::__simd_select( + _SimdType(std::numeric_limits::max()), + __w.__v_, __w.__m_)); +} template typename _SimdType::value_type -hmax(const const_where_expression<_MaskType, _SimdType>&); +hmax(const const_where_expression<_MaskType, _SimdType>& __w) { + return __hmax(__simd_mask_friend::__simd_select( + _SimdType(std::numeric_limits::lowest()), + __w.__v_, __w.__m_)); +} _LIBCPP_END_NAMESPACE_EXPERIMENTAL_SIMD Index: libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp =================================================================== --- libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp +++ libcxx/test/std/experimental/simd/simd.horizontal/hmax.pass.cpp @@ -20,7 +20,7 @@ using namespace std::experimental::parallelism_v2; -void test_hmax() { +void test_hmax_simd() { { int a[] = {2, 5, -4, 6}; assert(hmax(fixed_size_simd(a, element_aligned_tag())) == 6); @@ -39,4 +39,34 @@ } } -int main() { test_hmax(); } +void test_hmax_mask() { + assert(hmax(where(native_simd_mask(false), native_simd())) == + std::numeric_limits::min()); + { + int 
buffer[] = {2, 5, -4, 6}; + fixed_size_simd a(buffer, element_aligned_tag()); + assert(hmax(where(a <= 6, a)) == 6); + assert(hmax(where(a < 6, a)) == 5); + assert(hmax(where(a < 5, a)) == 2); + assert(hmax(where(a < 2, a)) == -4); + assert(hmax(where(a < -4, a)) == std::numeric_limits::min()); + } + { + bool buffer[] = {false, true, true, false}; + fixed_size_simd_mask a(buffer, element_aligned_tag()); + assert(hmax(where(fixed_size_simd_mask(true), a)) == true); + assert(hmax(where(!a, a)) == false); + } + + { + const fixed_size_simd a(0); + assert(hmax(where(fixed_size_simd_mask(true), a)) == 0.f); + assert(hmax(where(fixed_size_simd_mask(false), a)) == + std::numeric_limits::lowest()); + } +} + +int main() { + test_hmax_simd(); + test_hmax_mask(); +} Index: libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp =================================================================== --- libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp +++ libcxx/test/std/experimental/simd/simd.horizontal/hmin.pass.cpp @@ -20,7 +20,7 @@ using namespace std::experimental::parallelism_v2; -void test_hmin() { +void test_hmin_simd() { { int a[] = {2, 5, -4, 6}; assert(hmin(fixed_size_simd(a, element_aligned_tag())) == -4); @@ -39,4 +39,27 @@ } } -int main() { test_hmin(); } +void test_hmin_mask() { + assert(hmin(where(native_simd_mask(false), native_simd())) == + std::numeric_limits::max()); + { + int buffer[] = {2, 5, -4, 6}; + fixed_size_simd a(buffer, element_aligned_tag()); + assert(hmin(where(a >= -4, a)) == -4); + assert(hmin(where(a > -4, a)) == 2); + assert(hmin(where(a > 2, a)) == 5); + assert(hmin(where(a > 5, a)) == 6); + assert(hmin(where(a > 6, a)) == std::numeric_limits::max()); + } + { + bool buffer[] = {false, true, true, false}; + fixed_size_simd_mask a(buffer, element_aligned_tag()); + assert(hmin(where(fixed_size_simd_mask(true), a)) == false); + assert(hmin(where(a, a)) == true); + } +} + +int main() { + test_hmin_simd(); + test_hmin_mask(); +} 
Review notes for the reduce.pass.cpp diff below (free-form text between file diffs; patch(1) skips it).

This diff renames test_reduce to test_reduce_simd and adds test_reduce_mask. The new test checks the where-expression reduce overloads on two vectors:
* A fixed-size vector built from a lambda returning the lane index. The element type is stripped here, but presumably int. The assertions imply its lanes hold 0, 1, 2, 3, so exactly lanes 2 and 3 satisfy a >= 2.
* A mask vector set to {false, true, true, false}.

The forms covered are the three-argument form (explicit identity and operator), the default (plus) overload, and the multiplies, bit_and, bit_or and bit_xor overloads.

NOTE(review): angle-bracket arguments are stripped in this copy as well. For example, "std::plus()", "native_simd::size()" and "fixed_size_simd a(...)" have lost their template arguments, so this hunk cannot be applied or compiled from this text.

NOTE(review): every mask-vector case uses an all-true mask (fixed_size_simd_mask(true)). Identity substitution for masked-out bool lanes is therefore not exercised; a case with a partial mask would cover it.
Index: libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp =================================================================== --- libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp +++ libcxx/test/std/experimental/simd/simd.horizontal/reduce.pass.cpp @@ -42,7 +42,7 @@ inline int factorial(int n) { return n == 1 ? 1 : n * factorial(n - 1); } -void test_reduce() { +void test_reduce_simd() { int n = (int)native_simd::size(); assert(reduce(native_simd([](int i) { return i; })) == n * (n - 1) / 2); assert(reduce(native_simd([](int i) { return i; }), std::plus()) == @@ -51,4 +51,39 @@ std::multiplies()) == factorial(n)); } -int main() { test_reduce(); } +void test_reduce_mask() { + { + fixed_size_simd a([](int i) { return i; }); + assert(reduce(where(a < 2, a), 0, std::plus()) == 0 + 1); + assert(reduce(where(a >= 2, a), 1, std::multiplies()) == 2 * 3); + assert(reduce(where(a >= 2, a)) == 2 + 3); + assert(reduce(where(a >= 2, a), std::plus()) == 2 + 3); + assert(reduce(where(a >= 2, a), std::multiplies()) == 2 * 3); + assert(reduce(where(a >= 2, a), std::bit_and()) == (2 & 3)); + assert(reduce(where(a >= 2, a), std::bit_or()) == (2 | 3)); + assert(reduce(where(a >= 2, a), std::bit_xor()) == (2 ^ 3)); + } + { + fixed_size_simd_mask a; + a[0] = false; + a[1] = true; + a[2] = true; + a[3] = false; + assert(reduce(where(fixed_size_simd_mask(true), a)) == true); + assert(reduce(where(fixed_size_simd_mask(true), a), + std::plus()) == true); + assert(reduce(where(fixed_size_simd_mask(true), a), + std::multiplies()) == false); + assert(reduce(where(fixed_size_simd_mask(true), a), + std::bit_and()) == false); + assert(reduce(where(fixed_size_simd_mask(true), a), + std::bit_or()) == true); + assert(reduce(where(fixed_size_simd_mask(true), a), + std::bit_xor()) == false); + } +} + +int main() { + test_reduce_simd(); + test_reduce_mask(); +}