diff --git a/libc/src/string/string_utils.h b/libc/src/string/string_utils.h
--- a/libc/src/string/string_utils.h
+++ b/libc/src/string/string_utils.h
@@ -9,6 +9,7 @@
 #ifndef LIBC_SRC_STRING_STRING_UTILS_H
 #define LIBC_SRC_STRING_STRING_UTILS_H
 
+#include "src/__support/common.h"
 #include "utils/CPP/Bitset.h"
 #include <stddef.h> // size_t
 
@@ -58,23 +59,33 @@
 static inline char *string_token(char *__restrict src,
                                  const char *__restrict delimiter_string,
                                  char **__restrict saveptr) {
+  // A nullptr src means "continue from where the previous call stopped".
+  if (src == nullptr)
+    src = *saveptr;
+  // Neither a new string nor a saved position: nothing left to tokenize.
+  if (unlikely(src == nullptr))
+    return nullptr;
+
+  // Index the set via unsigned char: where char is signed, bytes >= 0x80 would
+  // otherwise yield indices outside [0, 256), out of range for the bitset.
   cpp::Bitset<256> delimiter_set;
-  for (; *delimiter_string; ++delimiter_string)
-    delimiter_set.set(*delimiter_string);
+  for (; *delimiter_string != '\0'; ++delimiter_string)
+    delimiter_set.set(static_cast<unsigned char>(*delimiter_string));
 
-  src = src ? src : *saveptr;
-  for (; *src && delimiter_set.test(*src); ++src)
+  for (; *src != '\0' && delimiter_set.test(static_cast<unsigned char>(*src));
+       ++src)
     ;
-  if (!*src) {
+  if (*src == '\0') {
     *saveptr = src;
     return nullptr;
   }
   char *token = src;
-  for (; *src && !delimiter_set.test(*src); ++src)
-    ;
-  if (*src) {
-    *src = '\0';
-    ++src;
+  for (; *src != '\0'; ++src) {
+    if (delimiter_set.test(static_cast<unsigned char>(*src))) {
+      *src = '\0';
+      ++src;
+      break;
+    }
   }
   *saveptr = src;
   return token;
diff --git a/libc/test/src/string/strtok_r_test.cpp b/libc/test/src/string/strtok_r_test.cpp
--- a/libc/test/src/string/strtok_r_test.cpp
+++ b/libc/test/src/string/strtok_r_test.cpp
@@ -80,6 +80,28 @@
   ASSERT_STREQ(__llvm_libc::strtok_r(nullptr, ",", &reserve), nullptr);
 }
 
+TEST(LlvmLibcStrTokReentrantTest,
+     ShouldReturnNullptrWhenBothSrcAndSaveptrAreNull) {
+  char *src = nullptr;
+  char *reserve = nullptr;
+  // Ensure that instead of crashing if src and reserve are null, nullptr is
+  // returned
+  ASSERT_STREQ(__llvm_libc::strtok_r(src, ",", &reserve), nullptr);
+  // And that neither src nor reserve are changed when that happens
+  ASSERT_STREQ(src, nullptr);
+  ASSERT_STREQ(reserve, nullptr);
+}
+
+TEST(LlvmLibcStrTokReentrantTest, ShouldHandleDelimitersWithHighBitSet) {
+  // '\xff' is negative where char is signed; it must still be treated as an
+  // ordinary delimiter byte.
+  char src[] = "ab\xffxy";
+  char *reserve = nullptr;
+  ASSERT_STREQ(__llvm_libc::strtok_r(src, "\xff", &reserve), "ab");
+  ASSERT_STREQ(__llvm_libc::strtok_r(nullptr, "\xff", &reserve), "xy");
+  ASSERT_STREQ(__llvm_libc::strtok_r(nullptr, "\xff", &reserve), nullptr);
+}
+
 TEST(LlvmLibcStrTokReentrantTest,
      SubsequentCallsShouldFindFollowingDelimiters) {
   char src[] = "12,34.56";