diff --git a/libc/src/string/CMakeLists.txt b/libc/src/string/CMakeLists.txt
--- a/libc/src/string/CMakeLists.txt
+++ b/libc/src/string/CMakeLists.txt
@@ -8,6 +8,8 @@
     libc.src.__support.CPP.bitset
     .memory_utils.memcpy_implementation
     .memory_utils.memset_implementation
+  FLAGS
+    LLVM_LIBC_STRLEN_UNSAFE
 )
 
 add_entrypoint_object(
@@ -157,6 +159,7 @@
   HDRS
     strlen.h
   DEPENDS
+    .string_utils
     libc.include.string
 )
 
diff --git a/libc/src/string/string_utils.h b/libc/src/string/string_utils.h
--- a/libc/src/string/string_utils.h
+++ b/libc/src/string/string_utils.h
@@ -18,15 +18,87 @@
 namespace __llvm_libc {
 namespace internal {
 
-// Returns the length of a string, denoted by the first occurrence
-// of a null terminator.
-static inline size_t string_length(const char *src) {
+// Returns a value of type T in which every byte equals the low 8 bits of
+// 'byte' (e.g. repeat_byte<uint32_t>(0x80) == 0x80808080).
+template <typename T> constexpr T repeat_byte(T byte) {
+  constexpr size_t BITS_IN_BYTE = 8;
+  constexpr size_t BYTE_MASK = 0xff;
+  T result = 0;
+  byte = byte & BYTE_MASK;
+  for (size_t i = 0; i < sizeof(T); ++i)
+    result = (result << BITS_IN_BYTE) | byte;
+  return result;
+}
+
+// The goal of this function is to take in a block of arbitrary size and return
+// if it has any bytes equal to zero without branching. This is done by
+// transforming the block such that zero bytes become non-zero and non-zero
+// bytes become zero.
+// The first transformation relies on the properties of borrowing in arithmetic
+// subtraction. Specifically, if 0x01 is subtracted from a byte that is 0x00,
+// then the result for that byte must be equal to 0xff (or 0xfe if the byte
+// below it had to borrow as well).
+// The next transformation is a simple mask. All zero bytes will have the high
+// bit set after the subtraction, so each byte is masked with 0x80. This narrows
+// the set of bytes that result in a non-zero value to only zero bytes and bytes
+// with the high bit and any other bit set.
+// The final transformation masks the result of the previous transformations
+// with the inverse of the original byte. This means that any byte that had the
+// high bit set will no longer have it set, narrowing the list of bytes which
+// result in non-zero values to just the zero byte.
+// T should be an unsigned integer type so that the subtraction wraps around.
+template <typename T> constexpr bool has_zeroes(T block) {
+  constexpr T LOW_BITS = repeat_byte<T>(0x01);
+  constexpr T HIGH_BITS = repeat_byte<T>(0x80);
+  T subtracted = block - LOW_BITS;
+  T inverted = ~block;
+  return (subtracted & inverted & HIGH_BITS) != 0;
+}
+
+// Returns the length of the null-terminated string 'src', reading it one
+// T-sized block at a time once the pointer is aligned to sizeof(T).
+// This is "unsafe" because a whole block is loaded even when the terminator is
+// not its last byte, so up to sizeof(T) - 1 bytes past the terminator are read.
+template <typename T>
+static inline size_t string_length_unsafe(const char *src) {
+  const char *char_ptr = src;
+  // Step 1: read 1 byte at a time to align to block size
+  for (; reinterpret_cast<uintptr_t>(char_ptr) % sizeof(T) != 0; ++char_ptr) {
+    if (*char_ptr == 0)
+      return char_ptr - src;
+  }
+  // Step 2: skip whole blocks until one contains a zero byte
+  const T *block_ptr = reinterpret_cast<const T *>(char_ptr);
+  while (!has_zeroes<T>(*block_ptr))
+    ++block_ptr;
+  char_ptr = reinterpret_cast<const char *>(block_ptr);
+  // Step 3: find the zero in the block
+  for (; *char_ptr != 0; ++char_ptr) {
+    ;
+  }
+  return char_ptr - src;
+}
+
+// Returns the length of the null-terminated string 'src' by checking one byte
+// at a time; it never reads past the terminator.
+static inline size_t string_length_safe(const char *src) {
   size_t length;
   for (length = 0; *src; ++src, ++length)
     ;
   return length;
 }
 
+// Returns the length of a string, denoted by the first occurrence
+// of a null terminator.
+static inline size_t string_length(const char *src) {
+#ifdef LLVM_LIBC_STRLEN_UNSAFE
+  // TODO: pick size of var based on requested size.
+  return string_length_unsafe<unsigned int>(src);
+#else
+  return string_length_safe(src);
+#endif
+}
+
 // Returns the first occurrence of 'ch' within the first 'n' characters of
 // 'src'. If 'ch' is not found, returns nullptr.
 static inline void *find_first_character(const unsigned char *src,