diff --git a/libc/src/string/memory_utils/elements_aarch64.h b/libc/src/string/memory_utils/elements_aarch64.h
--- a/libc/src/string/memory_utils/elements_aarch64.h
+++ b/libc/src/string/memory_utils/elements_aarch64.h
@@ -81,9 +81,47 @@
 using _3 = __llvm_libc::scalar::_3;
 using _4 = __llvm_libc::scalar::_4;
 using _8 = __llvm_libc::scalar::_8;
-using _16 = __llvm_libc::scalar::_16;
 
 #ifdef __ARM_NEON
+// 16-byte comparison using a single NEON register per operand.
+struct N16 {
+  static constexpr size_t SIZE = 16;
+
+  // Compares the 16 bytes at `lhs` and `rhs` and returns a 64-bit mask with
+  // one nibble per byte: nibble i is 0xF when lhs[i] != rhs[i], 0x0 otherwise.
+  // A zero result therefore means the two blocks are identical.
+  static uint64_t mismatch_mask(const char *lhs, const char *rhs) {
+    const uint8x16_t l = vld1q_u8((const uint8_t *)lhs);
+    const uint8x16_t r = vld1q_u8((const uint8_t *)rhs);
+    // 0xFF in each lane where the bytes are equal, 0x00 where they differ.
+    const uint8x16_t eq = vceqq_u8(l, r);
+    // Shift every 16-bit lane right by 4 and narrow to 8 bits; this keeps one
+    // nibble from each byte lane, packing 16 lanes into 64 bits.
+    const uint8x8_t packed = vshrn_n_u16(vreinterpretq_u16_u8(eq), 4);
+    // vget_lane_u64 takes a uint64x1_t: reinterpret explicitly instead of
+    // relying on lax vector conversions, which GCC rejects by default.
+    return ~vget_lane_u64(vreinterpret_u64_u8(packed), 0);
+  }
+
+  static bool equals(const char *lhs, const char *rhs) {
+    return mismatch_mask(lhs, rhs) == 0;
+  }
+
+  // Returns the difference of the first mismatching bytes (as unsigned char),
+  // or 0 if the blocks are equal.
+  static int three_way_compare(const char *lhs, const char *rhs) {
+    const uint64_t mask = mismatch_mask(lhs, rhs);
+    if (mask == 0)
+      return 0;
+    // Each byte owns 4 bits of the mask, so the index of the first differing
+    // byte is the count of trailing zero bits divided by 4.
+    const size_t index = __builtin_ctzll(mask) >> 2;
+    const int ca = (unsigned char)lhs[index];
+    const int cb = (unsigned char)rhs[index];
+    return ca - cb;
+  }
+};
+
 struct N32 {
   static constexpr size_t SIZE = 32;
   static bool equals(const char *lhs, const char *rhs) {
@@ -113,9 +151,11 @@
   }
 };
 
+using _16 = N16;
 using _32 = N32;
 using _64 = Repeated<_32, 2>;
 #else
+using _16 = __llvm_libc::scalar::_16;
 using _32 = __llvm_libc::scalar::_32;
 using _64 = __llvm_libc::scalar::_64;
 #endif // __ARM_NEON