Index: libcxx/include/experimental/simd =================================================================== --- libcxx/include/experimental/simd +++ libcxx/include/experimental/simd @@ -592,6 +592,7 @@ #include #include #include +#include #include #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -1341,6 +1342,20 @@ return __concat_array(__arr, std::make_index_sequence<_Np>()); } +template +simd<_Up, _Abi> __bit_cast(const simd<_Tp, _Abi>& __v) { + static_assert(std::is_arithmetic<_Up>::value, ""); + static_assert(sizeof(_Up) == sizeof(_Tp), ""); + simd<_Up, _Abi> __ret; + for (size_t __i = 0; __i < __v.size(); __i++) { + _Up __tmp; + _Tp __elem = __v[__i]; + memcpy(&__tmp, &__elem, sizeof(__elem)); + __ret[__i] = __tmp; + } + return __ret; +} + struct __simd_mask_friend { template static fixed_size_simd_mask<_Tp, simd_size<_Tp, _Abi>::value> @@ -1428,6 +1443,31 @@ } return concat(__arr); } + + template + static simd<_Tp, _Abi> __simd_select(const simd<_Tp, _Abi>& __false_values, + const simd<_Tp, _Abi>& __true_values, + const simd_mask<_Tp, _Abi>& __m) { + using __unsigned = typename __unsigned_traits::type; + return __bit_cast<_Tp>( + (__bit_cast<__unsigned>(__false_values) & ~__m.__s_) + + (__bit_cast<__unsigned>(__true_values) & __m.__s_)); + } + + template + static simd_mask<_Tp, _Abi> + __simd_select(const simd_mask<_Tp, _Abi>& __false_values, + const simd_mask<_Tp, _Abi>& __true_values, + const simd_mask<_Tp, _Abi>& __m) { + using __unsigned = typename __unsigned_traits::type; + return __simd_select(__false_values.__s_, __true_values.__s_, + simd_mask<__unsigned, _Abi>(__m.__s_)); + } + + template + static _Tp __simd_select(_Tp __false_value, _Tp __true_value, bool __m) { + return __m ? __true_value : __false_value; + } }; template @@ -1593,38 +1633,6 @@ return 0; } -// masked assignment [simd.whereexpr] -template -class const_where_expression; -template -class where_expression; - -// masked assignment [simd.mask.where] -template -where_expression, simd<_Tp, _Abi>> -where(const typename simd<_Tp, _Abi>::mask_type&, simd<_Tp, _Abi>&) noexcept; - -template -const_where_expression, const simd<_Tp, _Abi>> -where(const typename simd<_Tp, _Abi>::mask_type&, - const simd<_Tp, _Abi>&) noexcept; - -template -where_expression, simd_mask<_Tp, _Abi>> -where(const typename __nodeduce>::type&, - simd_mask<_Tp, _Abi>&) noexcept; - -template -const_where_expression, const simd_mask<_Tp, _Abi>> -where(const typename __nodeduce>::type&, - const simd_mask<_Tp, _Abi>&) noexcept; - -template -where_expression where(bool, _Tp&) noexcept; - -template -const_where_expression where(bool, const _Tp&) noexcept; - // reductions [simd.reductions] template > _Tp reduce(const simd<_Tp, _Abi>& __v, _BinaryOp __op = _BinaryOp()) { @@ -1635,36 +1643,6 @@ return __acc; } -template -typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - typename _SimdType::value_type neutral_element, _BinaryOp binary_op); - -template -typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - plus binary_op = {}); - -template -typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - multiplies binary_op); - -template -typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - bit_and binary_op); - -template -typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - bit_or binary_op); - -template -typename _SimdType::value_type -reduce(const const_where_expression<_MaskType, _SimdType>&, - bit_xor binary_op); - template _Tp hmin(const simd<_Tp, _Abi>& __v) { _Tp __acc = __v[0]; @@ -1674,10 +1652,6 @@ return __acc; } -template -typename _SimdType::value_type -hmin(const const_where_expression<_MaskType, _SimdType>&); - template _Tp hmax(const simd<_Tp, _Abi>& __v) { _Tp __acc = __v[0]; @@ -1687,10 +1661,6 @@ return __acc; } -template -typename _SimdType::value_type -hmax(const const_where_expression<_MaskType, _SimdType>&); - // algorithms [simd.alg] template simd<_Tp, _Abi> min(const simd<_Tp, _Abi>& __a, @@ -1724,53 +1694,6 @@ return min(max(__v, __lo), __hi); } -// [simd.whereexpr] -// TODO implement where expressions. -template -class const_where_expression { -public: - const_where_expression(const const_where_expression&) = delete; - const_where_expression& operator=(const const_where_expression&) = delete; - typename remove_const<_Tp>::type operator-() const&&; - template - void copy_to(_Up*, _Flags) const&&; -}; - -template -class where_expression : public const_where_expression<_MaskType, _Tp> { -public: - where_expression(const where_expression&) = delete; - where_expression& operator=(const where_expression&) = delete; - template - void operator=(_Up&&); - template - void operator+=(_Up&&); - template - void operator-=(_Up&&); - template - void operator*=(_Up&&); - template - void operator/=(_Up&&); - template - void operator%=(_Up&&); - template - void operator&=(_Up&&); - template - void operator|=(_Up&&); - template - void operator^=(_Up&&); - template - void operator<<=(_Up&&); - template - void operator>>=(_Up&&); - void operator++(); - void operator++(int); - void operator--(); - void operator--(int); - template - void copy_from(const _Up*, _Flags); -}; - // [simd.class] template class simd { @@ -2258,6 +2181,340 @@ } }; +template +void __mask_copy_to(const simd<_Tp, _Abi>& __v, const simd_mask<_Tp, _Abi>& __m, + _Up* __buffer, _Flags) { + for (size_t __i = 0; __i < __v.size(); __i++) { + if (__m[__i]) { + __buffer[__i] = static_cast<_Up>(__v[__i]); + } + } +} + +template +void __mask_copy_to(const simd_mask<_Tp, _Abi>& __v, + const simd_mask<_Tp, _Abi>& __m, _Up* __buffer, _Flags) { + for (size_t __i = 0; __i < __v.size(); __i++) { + if (__m[__i]) { + __buffer[__i] = static_cast<_Up>(__v[__i]); + } + } +} + +template +void __mask_copy_to(_Tp __v, bool __m, _Up* __buffer, _Flags) { + if (__m) { + *__buffer = static_cast<_Up>(__v); + } +} + +template +void __mask_copy_from(simd<_Tp, _Abi>& __v, const simd_mask<_Tp, _Abi>& __m, + const _Up* __buffer, _Flags) { + // TODO: optimize for overaligned flags + for (size_t __i = 0; __i < __v.size(); __i++) { + if (__m[__i]) { + __v[__i] = static_cast<_Tp>(__buffer[__i]); + } + } +} + +template +void __mask_copy_from(simd_mask<_Tp, _Abi>& __v, + const simd_mask<_Tp, _Abi>& __m, const _Up* __buffer, + _Flags) { + // TODO: optimize based on bool's bit pattern. + for (size_t __i = 0; __i < __v.size(); __i++) { + if (__m[__i]) { + __v[__i] = static_cast(__buffer[__i]); + } + } +} + +template +void __mask_copy_from(_Tp& __v, bool __m, const _Up* __buffer, _Flags) { + if (__m) { + __v = static_cast<_Tp>(*__buffer); + } +} + +template +struct __simd_value_type_traits { + static_assert(std::is_arithmetic<_ValueType>::value, ""); + using type = _ValueType; +}; + +template +struct __simd_value_type_traits> { + static_assert(std::is_arithmetic<_Tp>::value, ""); + using type = _Tp; +}; + +template +struct __simd_value_type_traits> { + static_assert(std::is_arithmetic<_Tp>::value, ""); + using type = _Tp; +}; + +// [simd.whereexpr] +template +class const_where_expression { + static_assert( + std::is_arithmetic::type>::value || + is_simd::type>::value || + is_simd_mask::type>::value, + ""); + + using _Tp = typename __simd_value_type_traits< + typename remove_const<_ValueType>::type>::type; + + typename std::conditional::value, bool, + const _MaskType>::type __m_; + _ValueType& __v_; + + const_where_expression(const _MaskType& __m, _ValueType& __v) + : __m_(__m), __v_(__v) {} + + const_where_expression(const const_where_expression&) = default; + + template + friend class where_expression; + + template + friend const_where_expression, const simd<_Up, _Ap>> + where(const typename simd<_Up, _Ap>::mask_type& __m, + const simd<_Up, _Ap>& __v) noexcept; + + template + friend const_where_expression, const simd_mask<_Up, _Ap>> + where(const typename __nodeduce>::type& __m, + const simd_mask<_Up, _Ap>& __v) noexcept; + + template + friend typename std::enable_if::value, + const_where_expression>::type + where(_Mp __m, const _Up& __v) noexcept; + +public: + const_where_expression& operator=(const const_where_expression&) = delete; + + typename std::remove_const<_ValueType>::type operator-() const&& { + static_assert(!is_simd_mask::type>::value, + "Library extension: operator-() doesn't really make sense " + "when operating on simd_mask<>."); + return __simd_mask_friend::__simd_select(__v_, _ValueType(0), __m_) - + __simd_mask_friend::__simd_select(_ValueType(0), __v_, __m_); + } + + template + typename std::enable_if::value || + !std::is_same<_Tp, bool>::value>::type + copy_to(_Up* __buffer, _Flags) const&& { + __mask_copy_to(__v_, __m_, __buffer, _Flags()); + } +}; + +template +class where_expression : public const_where_expression<_MaskType, _ValueType> { + using _Tp = typename __simd_value_type_traits< + typename remove_const<_ValueType>::type>::type; + + where_expression(const _MaskType& __m, _ValueType& __v) + : const_where_expression<_MaskType, _ValueType>(__m, __v) {} + + where_expression(const where_expression&) = default; + + template + friend where_expression, simd<_Up, _Ap>> + where(const typename simd<_Up, _Ap>::mask_type& __m, + simd<_Up, _Ap>& __v) noexcept; + + template + friend where_expression, simd_mask<_Up, _Ap>> + where(const typename __nodeduce>::type& __m, + simd_mask<_Up, _Ap>& __v) noexcept; + + template + friend typename std::enable_if::value, + where_expression>::type + where(_Mp __m, _Up& __v) noexcept; + +public: + where_expression& operator=(const where_expression&) = delete; + + template + auto operator=(_Up&& __u) + -> decltype(this->__v_ = std::forward<_Up>(__u), void()) { + this->__v_ = __simd_mask_friend::__simd_select( + this->__v_, _ValueType(std::forward<_Up>(__u)), this->__m_); + } + + template + auto operator+=(_Up&& __u) + -> decltype(this->__v_ + std::forward<_Up>(__u), void()) { + *this = this->__v_ + std::forward<_Up>(__u); + } + + template + auto operator-=(_Up&& __u) + -> decltype(this->__v_ - std::forward<_Up>(__u), void()) { + *this = this->__v_ - std::forward<_Up>(__u); + } + + template + auto operator*=(_Up&& __u) + -> decltype(this->__v_ * std::forward<_Up>(__u), void()) { + *this = this->__v_ * std::forward<_Up>(__u); + } + + template + auto operator/=(_Up&& __u) + -> decltype(this->__v_ / std::forward<_Up>(__u), void()) { + this->__v_ = + this->__v_ / + __simd_mask_friend::__simd_select( + _ValueType(1), _ValueType(std::forward<_Up>(__u)), this->__m_); + } + + template + auto operator%=(_Up&& __u) + -> decltype(this->__v_ % std::forward<_Up>(__u), void()) { + this->__v_ = __simd_mask_friend::__simd_select( + this->__v_, + this->__v_ % + __simd_mask_friend::__simd_select( + _ValueType(1), _ValueType(std::forward<_Up>(__u)), this->__m_), + this->__m_); + } + + template + auto operator&=(_Up&& __u) + -> decltype(this->__v_ & std::forward<_Up>(__u), void()) { + *this = this->__v_ & std::forward<_Up>(__u); + } + + template + auto operator|=(_Up&& __u) + -> decltype(this->__v_ | std::forward<_Up>(__u), void()) { + *this = this->__v_ | std::forward<_Up>(__u); + } + + template + auto operator^=(_Up&& __u) + -> decltype(this->__v_ ^ std::forward<_Up>(__u), void()) { + *this = this->__v_ ^ std::forward<_Up>(__u); + } + + template + auto operator<<=(_Up&& __u) + -> decltype(this->__v_ << std::forward<_Up>(__u), void()) { + *this = this->__v_ << std::forward<_Up>(__u); + } + + template + auto operator>>=(_Up&& __u) + -> decltype(this->__v_ >> std::forward<_Up>(__u), void()) { + *this = this->__v_ >> std::forward<_Up>(__u); + } + + void operator++() { *this += _ValueType(1); } + + void operator++(int) { ++*this; } + + void operator--() { *this -= _ValueType(1); } + + void operator--(int) { --*this; } + + template + typename std::enable_if::value || + !std::is_same<_Tp, bool>::value>::type + copy_from(const _Up* __buffer, _Flags) { + __mask_copy_from(this->__v_, this->__m_, __buffer, _Flags()); + } +}; + +template +where_expression, simd<_Tp, _Abi>> +where(const typename simd<_Tp, _Abi>::mask_type& __m, + simd<_Tp, _Abi>& __v) noexcept { + return where_expression, simd<_Tp, _Abi>>(__m, __v); +} + +template +const_where_expression, const simd<_Tp, _Abi>> +where(const typename simd<_Tp, _Abi>::mask_type& __m, + const simd<_Tp, _Abi>& __v) noexcept { + return const_where_expression, const simd<_Tp, _Abi>>( + __m, __v); +} + +template +where_expression, simd_mask<_Tp, _Abi>> +where(const typename __nodeduce>::type& __m, + simd_mask<_Tp, _Abi>& __v) noexcept { + return where_expression, simd_mask<_Tp, _Abi>>(__m, __v); +} + +template +const_where_expression, const simd_mask<_Tp, _Abi>> +where(const typename __nodeduce>::type& __m, + const simd_mask<_Tp, _Abi>& __v) noexcept { + return const_where_expression, + const simd_mask<_Tp, _Abi>>(__m, __v); +} + +template +typename std::enable_if::value, + where_expression>::type +where(_MaskType __m, _Tp& __v) noexcept { + return where_expression(__m, __v); +} + +template +typename std::enable_if::value, + const_where_expression>::type +where(_MaskType __m, const _Tp& __v) noexcept { + return const_where_expression(__m, __v); +} + +template +typename _SimdType::value_type +reduce(const const_where_expression<_MaskType, _SimdType>&, + typename _SimdType::value_type neutral_element, _BinaryOp binary_op); + +template +typename _SimdType::value_type +reduce(const const_where_expression<_MaskType, _SimdType>&, + plus binary_op = {}); + +template +typename _SimdType::value_type +reduce(const const_where_expression<_MaskType, _SimdType>&, + multiplies binary_op); + +template +typename _SimdType::value_type +reduce(const const_where_expression<_MaskType, _SimdType>&, + bit_and binary_op); + +template +typename _SimdType::value_type +reduce(const const_where_expression<_MaskType, _SimdType>&, + bit_or binary_op); + +template +typename _SimdType::value_type +reduce(const const_where_expression<_MaskType, _SimdType>&, + bit_xor binary_op); + +template +typename _SimdType::value_type +hmin(const const_where_expression<_MaskType, _SimdType>&); + +template +typename _SimdType::value_type +hmax(const const_where_expression<_MaskType, _SimdType>&); + _LIBCPP_END_NAMESPACE_EXPERIMENTAL_SIMD #endif /* _LIBCPP_EXPERIMENTAL_SIMD */ Index: libcxx/test/std/experimental/simd/simd.whereexpr/const_where_expression.pass.cpp =================================================================== --- /dev/null +++ libcxx/test/std/experimental/simd/simd.whereexpr/const_where_expression.pass.cpp @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source Licenses. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++98, c++03 + +// +// +// // [simd.whereexpr] +// template +// class const_where_expression { +// const M& mask; // exposition only +// T& data; // exposition only +// public: +// const_where_expression(const const_where_expression&) = delete; +// const_where_expression& operator=(const const_where_expression&) = delete; +// remove_const_t operator-() const &&; +// template void copy_to(U* mem, Flags f) const &&; +// }; + +#include +#include +#include +#include + +using namespace std::experimental::parallelism_v2; + +void test_operator_minus() { + { + const fixed_size_simd a([](int i) { return i; }); + auto b = -where(a < 2, a); + assert(b[0] == 0); + assert(b[1] == -1); + assert(b[2] == 2); + assert(b[3] == 3); + } + assert((-where(true, 3)) == -3); + assert((-where(false, 3)) == 3); +} + +void test_copy_to() { + { + const fixed_size_simd a([](int i) { return i - 2; }); + int buffer[] = {1, 2, 3, 4}; + where(a < 0, a).copy_to(buffer, element_aligned_tag()); + assert(buffer[0] == -2); + assert(buffer[1] == -1); + assert(buffer[2] == 3); + assert(buffer[3] == 4); + } + { + const fixed_size_simd a([](int i) { return i - 2; }); + int buffer[] = {1, 2, 3, 4}; + where(a >= 0, a).copy_to(buffer, element_aligned_tag()); + assert(buffer[0] == 1); + assert(buffer[1] == 2); + assert(buffer[2] == 0); + assert(buffer[3] == 1); + } + { + fixed_size_simd_mask a; + { + bool input[] = {false, true, true, false}; + a.copy_from(input, element_aligned_tag()); + } + { + bool buffer[] = {true, true, false, false}; + where(a, a).copy_to(buffer, element_aligned_tag()); + assert(buffer[0]); + assert(buffer[1]); + assert(buffer[2]); + assert(!buffer[3]); + } + { + bool buffer[] = {true, true, false, false}; + where(!a, a).copy_to(buffer, element_aligned_tag()); + assert(!buffer[0]); + assert(buffer[1]); + assert(!buffer[2]); + assert(!buffer[3]); + } + } + { + int b = 1; + where(true, 3).copy_to(&b, element_aligned_tag()); + assert(b == 3); + } + { + int b = 1; + where(false, 3).copy_to(&b, element_aligned_tag()); + assert(b == 1); + } +} + +int main() { + test_operator_minus(); + test_copy_to(); +} Index: libcxx/test/std/experimental/simd/simd.whereexpr/where.pass.cpp =================================================================== --- /dev/null +++ libcxx/test/std/experimental/simd/simd.whereexpr/where.pass.cpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source Licenses. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++98, c++03 + +// +// +// // masked assignment [simd.mask.where] +// template +// where_expression, simd> +// where(const typename simd::mask_type&, simd&) noexcept; +// +// template +// const_where_expression, const simd> +// where(const typename simd::mask_type&, const simd&) noexcept; +// +// template +// where_expression, simd_mask> +// where(const nodeduce_t>&, simd_mask&) noexcept; +// +// template +// const_where_expression, const simd_mask> +// where(const nodeduce_t>&, const simd_mask&) noexcept; +// +// template where_expression where(see below k, T& d) noexcept; +// +// template +// const_where_expression where(see below k, const T& d) noexcept; + +#include +#include +#include +#include + +using namespace std::experimental::parallelism_v2; + +void compile_const_where() { + { + const native_simd a{}; + static_assert( + std::is_same, + const native_simd>>::value, + ""); + } + { + const native_simd_mask a{}; + static_assert( + std::is_same< + decltype(where(a, a)), + const_where_expression, + const native_simd_mask>>::value, + ""); + } + { + const bool b = true; + static_assert(std::is_same>::value, + ""); + } +} + +void compile_where() { + { + native_simd a; + static_assert( + std::is_same< + decltype(where(a < 2, a)), + where_expression, native_simd>>::value, + ""); + } + { + native_simd_mask a; + static_assert(std::is_same, + native_simd_mask>>::value, + ""); + } + { + int v = 3; + static_assert( + std::is_same>::value, + ""); + } +} + +int main() {} Index: libcxx/test/std/experimental/simd/simd.whereexpr/where_expression.pass.cpp =================================================================== --- /dev/null +++ libcxx/test/std/experimental/simd/simd.whereexpr/where_expression.pass.cpp @@ -0,0 +1,366 @@ +//===----------------------------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source Licenses. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++98, c++03 + +// +// +// // [simd.whereexpr] +// template +// class where_expression : public const_where_expression { +// public: +// where_expression(const where_expression&) = delete; +// where_expression& operator=(const where_expression&) = delete; +// template void operator=(U&& x); +// template void operator+=(U&& x); +// template void operator-=(U&& x); +// template void operator*=(U&& x); +// template void operator/=(U&& x); +// template void operator%=(U&& x); +// template void operator&=(U&& x); +// template void operator|=(U&& x); +// template void operator^=(U&& x); +// template void operator<<=(U&& x); +// template void operator>>=(U&& x); +// void operator++(); +// void operator++(int); +// void operator--(); +// void operator--(int); +// template void copy_from(const U* mem, Flags); +// }; + +#include +#include +#include +#include + +using namespace std::experimental::parallelism_v2; + +void test_operators_simd() { + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) = -1; + assert(a[0] == -1); + assert(a[1] == -1); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) = fixed_size_simd(-1); + assert(a[0] == -1); + assert(a[1] == -1); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) += -1; + assert(a[0] == -1); + assert(a[1] == 0); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) += fixed_size_simd(-1); + assert(a[0] == -1); + assert(a[1] == 0); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) -= -1; + assert(a[0] == 1); + assert(a[1] == 2); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) -= fixed_size_simd(-1); + assert(a[0] == 1); + assert(a[1] == 2); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) *= -1; + assert(a[0] == 0); + assert(a[1] == -1); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) *= fixed_size_simd(-1); + assert(a[0] == 0); + assert(a[1] == -1); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return 3 * i; }); + where(a >= 6, a) /= 2; + assert(a[0] == 0); + assert(a[1] == 3); + assert(a[2] == 3); + assert(a[3] == 4); + } + { + fixed_size_simd a([](int i) { return 3 * i; }); + where(a >= 6, a) /= fixed_size_simd(2); + assert(a[0] == 0); + assert(a[1] == 3); + assert(a[2] == 3); + assert(a[3] == 4); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a % 2 == 1, a) /= + fixed_size_simd([](int i) { return i % 2 * 2; }); + assert(a[0] == 0); + assert(a[1] == 0); + assert(a[2] == 2); + assert(a[3] == 1); + } + { + fixed_size_simd a([](int i) { return 3 * i; }); + where(a >= 6, a) %= 2; + assert(a[0] == 0); + assert(a[1] == 3); + assert(a[2] == 0); + assert(a[3] == 1); + } + { + fixed_size_simd a([](int i) { return 3 * i; }); + where(a >= 6, a) %= fixed_size_simd(2); + assert(a[0] == 0); + assert(a[1] == 3); + assert(a[2] == 0); + assert(a[3] == 1); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a % 2 == 1, a) %= + fixed_size_simd([](int i) { return i % 2 * 2; }); + assert(a[0] == 0); + assert(a[1] == 1); + assert(a[2] == 2); + assert(a[3] == 1); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a > -2, a) &= 1; + assert(a[0] == 0); + assert(a[1] == 1); + assert(a[2] == 0); + assert(a[3] == 1); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a >= 2, a) &= fixed_size_simd(1); + assert(a[0] == 0); + assert(a[1] == 1); + assert(a[2] == 0); + assert(a[3] == 1); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) |= 2; + assert(a[0] == 2); + assert(a[1] == 3); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) |= fixed_size_simd(2); + assert(a[0] == 2); + assert(a[1] == 3); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) ^= 1; + assert(a[0] == 1); + assert(a[1] == 0); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) ^= fixed_size_simd(1); + assert(a[0] == 1); + assert(a[1] == 0); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) <<= 1; + assert(a[0] == 0); + assert(a[1] == 2); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a) <<= fixed_size_simd(1); + assert(a[0] == 0); + assert(a[1] == 2); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + fixed_size_simd a([](int i) { return 2 * i; }); + where(a < 4, a) >>= 1; + assert(a[0] == 0); + assert(a[1] == 1); + assert(a[2] == 4); + assert(a[3] == 6); + } + { + fixed_size_simd a([](int i) { return 2 * i; }); + where(a < 4, a) >>= fixed_size_simd(1); + assert(a[0] == 0); + assert(a[1] == 1); + assert(a[2] == 4); + assert(a[3] == 6); + } +} + +void test_operators_mask() { + { + fixed_size_simd_mask a; + a[0] = false; + a[1] = true; + a[2] = true; + a[3] = false; + where(a, a) = fixed_size_simd_mask(false); + assert(!a[0]); + assert(!a[1]); + assert(!a[2]); + assert(!a[3]); + } + { + fixed_size_simd_mask a; + a[0] = false; + a[1] = true; + a[2] = true; + a[3] = false; + where(a, a) &= fixed_size_simd_mask(false); + assert(!a[0]); + assert(!a[1]); + assert(!a[2]); + assert(!a[3]); + } + { + fixed_size_simd_mask a; + a[0] = false; + a[1] = true; + a[2] = true; + a[3] = false; + where(!a, a) |= fixed_size_simd_mask(true); + assert(a[0]); + assert(a[1]); + assert(a[2]); + assert(a[3]); + } + { + fixed_size_simd_mask a; + a[0] = false; + a[1] = true; + a[2] = true; + a[3] = false; + where(a, a) ^= fixed_size_simd_mask(true); + assert(!a[0]); + assert(!a[1]); + assert(!a[2]); + assert(!a[3]); + } + { + fixed_size_simd_mask a; + a[0] = false; + a[1] = true; + a[2] = true; + a[3] = false; + where(!a, a) ^= fixed_size_simd_mask(true); + assert(a[0]); + assert(a[1]); + assert(a[2]); + assert(a[3]); + } +} + +void test_copy_from() { + { + const int buffer[] = {-1, -2, -3, -4}; + fixed_size_simd a([](int i) { return i; }); + where(a < 2, a).copy_from(buffer, element_aligned_tag()); + assert(a[0] == -1); + assert(a[1] == -2); + assert(a[2] == 2); + assert(a[3] == 3); + } + { + const int buffer[] = {-1, -2, -3, -4}; + fixed_size_simd a([](int i) { return i; }); + where(a >= 2, a).copy_from(buffer, element_aligned_tag()); + assert(a[0] == 0); + assert(a[1] == 1); + assert(a[2] == -3); + assert(a[3] == -4); + } + { + fixed_size_simd_mask a; + const bool input[] = {false, true, true, false}; + a.copy_from(input, element_aligned_tag()); + + const bool buffer[] = {true, true, false, false}; + where(a, a).copy_from(buffer, element_aligned_tag()); + assert(!a[0]); + assert(a[1]); + assert(!a[2]); + assert(!a[3]); + } + { + fixed_size_simd_mask a; + const bool input[] = {false, true, true, false}; + a.copy_from(input, element_aligned_tag()); + + const bool buffer[] = {true, true, false, false}; + where(!a, a).copy_from(buffer, element_aligned_tag()); + assert(a[0]); + assert(a[1]); + assert(a[2]); + assert(!a[3]); + } + { + const int b = 1; + int a = 3; + where(true, a).copy_from(&b, element_aligned_tag()); + assert(a == 1); + } + { + const int b = 1; + int a = 3; + where(false, a).copy_from(&b, element_aligned_tag()); + assert(b == 1); + } +} + +int main() { + test_operators_simd(); + test_operators_mask(); + test_copy_from(); +}