Index: libc/src/string/CMakeLists.txt
===================================================================
--- libc/src/string/CMakeLists.txt
+++ libc/src/string/CMakeLists.txt
@@ -295,7 +295,7 @@
 
 function(add_memcmp memcmp_name)
   add_implementation(memcmp ${memcmp_name}
-    SRCS ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp
+    SRCS ${LIBC_MEMCMP_SRC}
     HDRS ${LIBC_SOURCE_DIR}/src/string/memcmp.h
     DEPENDS
       .memory_utils.memory_utils
@@ -307,13 +307,19 @@
 endfunction()
 
 if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
+  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
   add_memcmp(memcmp_x86_64_opt_sse2 COMPILE_OPTIONS -march=k8 REQUIRE SSE2)
   add_memcmp(memcmp_x86_64_opt_sse4 COMPILE_OPTIONS -march=nehalem REQUIRE SSE4_2)
   add_memcmp(memcmp_x86_64_opt_avx2 COMPILE_OPTIONS -march=haswell REQUIRE AVX2)
   add_memcmp(memcmp_x86_64_opt_avx512 COMPILE_OPTIONS -march=skylake-avx512 REQUIRE AVX512F)
   add_memcmp(memcmp_opt_host COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
   add_memcmp(memcmp)
+elseif(${LIBC_TARGET_ARCHITECTURE_IS_AARCH64})
+  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/aarch64/memcmp.cpp)
+  add_memcmp(memcmp)
+  add_memcmp(memcmp_opt_host COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
 else()
+  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
   add_memcmp(memcmp_opt_host COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
   add_memcmp(memcmp)
 endif()
Index: libc/src/string/aarch64/memcmp.cpp
===================================================================
--- /dev/null
+++ libc/src/string/aarch64/memcmp.cpp
@@ -0,0 +1,74 @@
+//===-- Implementation of memcmp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/string/memcmp.h"
+#include "src/__support/common.h"
+#include "src/string/memory_utils/elements.h"
+#include <stddef.h> // size_t
+
+namespace __llvm_libc {
+namespace aarch64 {
+
+// Three-way compares the first `count` bytes of `lhs` and `rhs`, dispatching
+// on `count` to a fixed-size strategy from memory_utils/elements.h.
+// Returns 0 when the ranges are equal, otherwise the non-zero value produced
+// by the selected ThreeWayCompare strategy.
+//
+// NOTE(review): the strategies named for [4, 7], [8, 15], [32, 63], [64, 127]
+// and [128, inf) were chosen as ones that are valid for those sizes - confirm
+// they match the intended tuning.
+static int memcmp_impl(const void *lhs, const void *rhs, size_t count) {
+  const char *l = reinterpret_cast<const char *>(lhs);
+  const char *r = reinterpret_cast<const char *>(rhs);
+  if (count == 0)
+    return 0;
+  if (count == 1)
+    return ThreeWayCompare<_1>(l, r);
+  else if (count == 2)
+    return ThreeWayCompare<_2>(l, r);
+  else if (count == 3)
+    return ThreeWayCompare<_3>(l, r);
+  else if (count < 8) // [4, 7]
+    return ThreeWayCompare<HeadTail<_4>>(l, r, count);
+  else if (count < 16) // [8, 15]
+    return ThreeWayCompare<HeadTail<_8>>(l, r, count);
+  else if (count < 128) { // [16, 127]
+    // `offset` is the start of the 16-byte block whose comparison decides the
+    // result; it stays at the first block that is known to differ.
+    size_t offset = 0;
+    if (Equals<_16>(l, r)) {
+      if (count < 32)
+        // [16, 31]: bytes [0, 16) match, the (overlapping) last 16 decide.
+        offset = count - 16;
+      else {
+        offset += 16;
+        if (Equals<_16>(l + offset, r + offset)) {
+          offset += 16;
+          // Bytes [0, 32) match.
+          if (count < 64) // [32, 63]: the last 32 bytes cover the rest.
+            return ThreeWayCompare<Tail<_32>>(l, r, count);
+          // [64, 127]: compare the remaining count - 32 bytes.
+          return ThreeWayCompare<Loop<_16>>(l + offset, r + offset,
+                                            count - offset);
+        }
+        // Otherwise bytes [16, 32) differ: keep offset == 16 so that block is
+        // the one compared below.
+      }
+    }
+    return ThreeWayCompare<_16>(l + offset, r + offset);
+  } else // [128, inf)
+    return ThreeWayCompare<Align<_16>::Then<Loop<_32>>>(l, r, count);
+}
+} // namespace aarch64
+
+LLVM_LIBC_FUNCTION(int, memcmp,
+                   (const void *lhs, const void *rhs, size_t count)) {
+  return aarch64::memcmp_impl(lhs, rhs, count);
+}
+
+} // namespace __llvm_libc
Index: libc/src/string/memory_utils/elements.h
===================================================================
--- libc/src/string/memory_utils/elements.h
+++ libc/src/string/memory_utils/elements.h
@@ -211,8 +211,8 @@
   }
 
   static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
-    if (const int result = T::ThreeWayCompare(lhs, rhs))
-      return result;
+    if (!T::Equals(lhs, rhs))
+      return T::ThreeWayCompare(lhs, rhs);
     return Tail<T>::ThreeWayCompare(lhs, rhs, size);
   }
 
@@ -251,8 +251,8 @@
 
   static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
     for (size_t offset = 0; offset < size - T::kSize; offset += T::kSize)
-      if (const int result = T::ThreeWayCompare(lhs + offset, rhs + offset))
-        return result;
+      if (!T::Equals(lhs + offset, rhs + offset))
+        return T::ThreeWayCompare(lhs + offset, rhs + offset);
     return Tail<T>::ThreeWayCompare(lhs, rhs, size);
   }
 
@@ -327,8 +327,8 @@
     }
 
     static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
-      if (const int result = AlignmentT::ThreeWayCompare(lhs, rhs))
-        return result;
+      if (!AlignmentT::Equals(lhs, rhs))
+        return AlignmentT::ThreeWayCompare(lhs, rhs);
       internal::AlignHelper<AlignmentT, AlignOn>::Bump(lhs, rhs, size);
       return NextT::ThreeWayCompare(lhs, rhs, size);
     }
@@ -370,12 +370,18 @@
 #endif
   }
 
+#if __has_builtin(__builtin_memcmp_inline)
+#define LLVM_LIBC_MEMCMP __builtin_memcmp_inline
+#else
+#define LLVM_LIBC_MEMCMP __builtin_memcmp
+#endif
+
   static bool Equals(const char *lhs, const char *rhs) {
-    return __builtin_memcmp(lhs, rhs, kSize) == 0;
+    return LLVM_LIBC_MEMCMP(lhs, rhs, kSize) == 0;
   }
 
   static int ThreeWayCompare(const char *lhs, const char *rhs) {
-    return __builtin_memcmp(lhs, rhs, kSize);
+    return LLVM_LIBC_MEMCMP(lhs, rhs, kSize);
   }
 
   static void SplatSet(char *dst, const unsigned char value) {
@@ -428,6 +434,7 @@
     Store(dst, GetSplattedValue(value));
   }
 
+  static int ScalarThreeWayCompare(T a, T b);
 private:
   static T Load(const char *ptr) {
     T value;
@@ -440,7 +447,6 @@
   static T GetSplattedValue(const unsigned char value) {
     return T(~0) / T(0xFF) * T(value);
   }
-  static int ScalarThreeWayCompare(T a, T b);
 };
 
 template <>
@@ -457,23 +463,15 @@
 }
 template <>
 inline int Scalar<uint32_t>::ScalarThreeWayCompare(uint32_t a, uint32_t b) {
-  const int64_t la = Endian::ToBigEndian(a);
-  const int64_t lb = Endian::ToBigEndian(b);
-  if (la < lb)
-    return -1;
-  if (la > lb)
-    return 1;
-  return 0;
+  const uint32_t la = Endian::ToBigEndian(a);
+  const uint32_t lb = Endian::ToBigEndian(b);
+  return la > lb ? 1 : la < lb ? -1 : 0;
 }
 template <>
 inline int Scalar<uint64_t>::ScalarThreeWayCompare(uint64_t a, uint64_t b) {
-  const __int128_t la = Endian::ToBigEndian(a);
-  const __int128_t lb = Endian::ToBigEndian(b);
-  if (la < lb)
-    return -1;
-  if (la > lb)
-    return 1;
-  return 0;
+  const uint64_t la = Endian::ToBigEndian(a);
+  const uint64_t lb = Endian::ToBigEndian(b);
+  return la > lb ? 1 : la < lb ? -1 : 0;
 }
 
 using UINT8 = Scalar<uint8_t>;   // 1 Byte
@@ -495,5 +493,6 @@
 } // namespace __llvm_libc
 
 #include <src/string/memory_utils/elements_x86.h>
+#include <src/string/memory_utils/elements_aarch64.h>
 
 #endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_H
Index: libc/src/string/memory_utils/elements_aarch64.h
===================================================================
--- /dev/null
+++ libc/src/string/memory_utils/elements_aarch64.h
@@ -0,0 +1,82 @@
+//===-- Elementary operations for aarch64 --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
+#define LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
+
+// elements.h pulls this header in for every target, so only define the NEON
+// based elements when compiling for aarch64.
+#if defined(__aarch64__)
+
+#include <src/string/memory_utils/elements.h>
+#include <arm_neon.h>
+#include <stddef.h> // size_t
+#include <stdint.h> // uint8_t, uint16_t, uint32_t, uint64_t
+
+namespace __llvm_libc {
+namespace aarch64 {
+
+using _1 = __llvm_libc::builtin::_1;
+using _2 = __llvm_libc::builtin::_2;
+using _3 = __llvm_libc::builtin::_3;
+using _4 = __llvm_libc::builtin::_4;
+using _8 = __llvm_libc::builtin::_8;
+using _16 = __llvm_libc::builtin::_16;
+
+// 32-byte element compared with NEON: two 16-byte loads per operand.
+struct N32 {
+  static constexpr size_t kSize = 32;
+
+  // Returns true when the 32 bytes at `lhs` and `rhs` are identical.
+  static bool Equals(const char *lhs, const char *rhs) {
+    return DifferenceMask(lhs, rhs) == 0;
+  }
+
+  // Returns 0 when the 32 bytes are identical, otherwise the result of
+  // comparing the first 4-byte group that differs.
+  static int ThreeWayCompare(const char *lhs, const char *rhs) {
+    const uint64_t mask = DifferenceMask(lhs, rhs);
+    if (mask == 0)
+      return 0;
+    // Mask byte k summarises input bytes [4k, 4k + 4): find the lowest
+    // non-zero mask byte and scale its index to a byte offset.
+    // NOTE(review): assumes little-endian lane order - confirm for big-endian.
+    const size_t index = (__builtin_ctzll(mask) >> 3) << 2;
+    // Copy instead of dereferencing a casted pointer: `lhs`/`rhs` carry no
+    // alignment guarantee.
+    uint32_t l;
+    uint32_t r;
+    __builtin_memcpy(&l, lhs + index, sizeof(l));
+    __builtin_memcpy(&r, rhs + index, sizeof(r));
+    return __llvm_libc::scalar::_4::ScalarThreeWayCompare(l, r);
+  }
+
+private:
+  // XORs both operands and folds the 32 difference bytes into 8 with two
+  // pairwise-max steps; the result is zero iff the operands are equal.
+  static uint64_t DifferenceMask(const char *lhs, const char *rhs) {
+    const uint8x16_t l_0 = vld1q_u8(reinterpret_cast<const uint8_t *>(lhs));
+    const uint8x16_t r_0 = vld1q_u8(reinterpret_cast<const uint8_t *>(rhs));
+    const uint8x16_t l_1 =
+        vld1q_u8(reinterpret_cast<const uint8_t *>(lhs + 16));
+    const uint8x16_t r_1 =
+        vld1q_u8(reinterpret_cast<const uint8_t *>(rhs + 16));
+    const uint8x16_t folded =
+        vpmaxq_u8(veorq_u8(l_0, r_0), veorq_u8(l_1, r_1));
+    return vgetq_lane_u64(vreinterpretq_u64_u8(vpmaxq_u8(folded, folded)), 0);
+  }
+};
+
+using _32 = N32;
+
+} // namespace aarch64
+} // namespace __llvm_libc
+
+#endif // defined(__aarch64__)
+
+#endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
Index:
libc/test/src/string/CMakeLists.txt
===================================================================
--- libc/test/src/string/CMakeLists.txt
+++ libc/test/src/string/CMakeLists.txt
@@ -52,27 +52,6 @@
     libc.src.string.memchr
 )
 
-add_libc_unittest(
-  memcmp_test
-  SUITE
-    libc_string_unittests
-  SRCS
-    memcmp_test.cpp
-  DEPENDS
-    libc.src.string.memcmp
-)
-
-add_libc_unittest(
-  memmove_test
-  SUITE
-    libc_string_unittests
-  SRCS
-    memmove_test.cpp
-  DEPENDS
-    libc.src.string.memcmp
-    libc.src.string.memmove
-)
-
 add_libc_unittest(
   strchr_test
   SUITE
@@ -209,3 +188,5 @@
 add_libc_multi_impl_test(memcpy SRCS memcpy_test.cpp)
 add_libc_multi_impl_test(memset SRCS memset_test.cpp)
 add_libc_multi_impl_test(bzero SRCS bzero_test.cpp)
+add_libc_multi_impl_test(memcmp SRCS memcmp_test.cpp)
+add_libc_multi_impl_test(memmove SRCS memmove_test.cpp)
Index: libc/test/src/string/memcmp_test.cpp
===================================================================
--- libc/test/src/string/memcmp_test.cpp
+++ libc/test/src/string/memcmp_test.cpp
@@ -32,3 +32,33 @@
   const char *rhs = "ab";
   EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, 2), 1);
 }
+
+TEST(LlvmLibcMemcmpTest, Sweep) {
+  static constexpr int kMaxSize = 1024;
+  char lhs[kMaxSize];
+  char rhs[kMaxSize];
+
+  // Equal prefixes of every length compare equal.
+  for (int i = 0; i < kMaxSize; ++i) {
+    lhs[i] = 'a';
+    rhs[i] = 'a';
+    EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, i), 0);
+  }
+  // memcmp only specifies the sign of a non-zero result, so check the sign
+  // rather than an exact value.
+  for (int i = 0; i < kMaxSize; ++i) {
+    rhs[i] = 'b';
+    EXPECT_LT(__llvm_libc::memcmp(lhs, rhs, kMaxSize), 0);
+    EXPECT_GT(__llvm_libc::memcmp(rhs, lhs, kMaxSize), 0);
+    rhs[i] = 'a';
+  }
+  // Sweep the position of the differing byte for every length below 256 as
+  // well, so that the code paths selected for short inputs are exercised.
+  for (int size = 1; size < 256; ++size) {
+    for (int i = 0; i < size; ++i) {
+      rhs[i] = 'b';
+      EXPECT_LT(__llvm_libc::memcmp(lhs, rhs, size), 0);
+      rhs[i] = 'a';
+    }
+  }
+}