diff --git a/libc/src/string/CMakeLists.txt b/libc/src/string/CMakeLists.txt
--- a/libc/src/string/CMakeLists.txt
+++ b/libc/src/string/CMakeLists.txt
@@ -58,6 +58,8 @@
     memchr.cpp
   HDRS
     memchr.h
+  DEPENDS
+    .string_utils
 )
 
 add_entrypoint_object(
@@ -74,6 +76,9 @@
     strstr.cpp
   HDRS
     strstr.h
+  DEPENDS
+    .string_utils
+    libc.utils.CPP.standalone_cpp
 )
 
 add_entrypoint_object(
diff --git a/libc/src/string/memchr.cpp b/libc/src/string/memchr.cpp
--- a/libc/src/string/memchr.cpp
+++ b/libc/src/string/memchr.cpp
@@ -8,17 +8,14 @@
 
 #include "src/string/memchr.h"
 #include "src/__support/common.h"
+#include "src/string/string_utils.h"
 #include <stddef.h>
 
 namespace __llvm_libc {
 
 // TODO: Look at performance benefits of comparing words.
 void *LLVM_LIBC_ENTRYPOINT(memchr)(const void *src, int c, size_t n) {
-  const unsigned char *str = reinterpret_cast<const unsigned char *>(src);
-  const unsigned char ch = c;
-  for (; n && *str != ch; --n, ++str)
-    ;
-  return n ? const_cast<unsigned char *>(str) : nullptr;
+  return internal::memchr(src, c, n);
 }
 
 } // namespace __llvm_libc
diff --git a/libc/src/string/string_utils.h b/libc/src/string/string_utils.h
--- a/libc/src/string/string_utils.h
+++ b/libc/src/string/string_utils.h
@@ -15,6 +15,9 @@
 namespace __llvm_libc {
 namespace internal {
 
+// Returns the larger of 'a' and 'b'.
+template <typename T> constexpr T Max(T a, T b) { return a < b ? b : a; }
+
 // Returns the maximum length span that contains only characters not found in
 // 'segment'. If no characters are found, returns the length of 'src'.
 static inline size_t complementary_span(const char *src, const char *segment) {
@@ -62,6 +65,25 @@
   return token;
 }
 
+// TODO: Remove this when __llvm_libc::memcmp is implemented.
+// Compares the first 'n' bytes of 'left' and 'right'. Returns 0 if they are
+// all equal, otherwise the difference of the first pair of differing bytes.
+static inline int memcmp(const unsigned char *left, const unsigned char *right,
+                         size_t n) {
+  for (; n && *left == *right; --n, ++left, ++right)
+    ;
+  return n ? *left - *right : 0;
+}
+
+// Returns a pointer to the first byte equal to 'ch' within the first 'n' bytes
+// of 'src', or nullptr if there is no such byte.
+static inline void *memchr(const void *src, unsigned char ch, size_t n) {
+  const unsigned char *str = reinterpret_cast<const unsigned char *>(src);
+  for (; n && *str != ch; --n, ++str)
+    ;
+  return n ? const_cast<unsigned char *>(str) : nullptr;
+}
+
 } // namespace internal
 } // namespace __llvm_libc
 
diff --git a/libc/src/string/strstr.cpp b/libc/src/string/strstr.cpp
--- a/libc/src/string/strstr.cpp
+++ b/libc/src/string/strstr.cpp
@@ -9,16 +9,172 @@
 #include "src/string/strstr.h"
 
 #include "src/__support/common.h"
+#include "src/string/string_utils.h"
+#include "utils/CPP/Bitset.h"
 #include <stddef.h>
 
 namespace __llvm_libc {
 
-// TODO: This is a simple brute force implementation. This can be
-// improved upon using well known string matching algorithms.
-char *LLVM_LIBC_ENTRYPOINT(strstr)(const char *haystack, const char *needle) {
+// Used to store the factorization for a given string.
+struct Factorization {
+  // The maximal suffix start index for a factorization.
+  size_t MaximalSuffixIndex;
+  // The period of the string. P is a period of a string X if two letters at
+  // distance P always coincide.
+  size_t Period;
+  // Since a mismatch will only advance by 'Period' characters, we use
+  // this to avoid rescanning known occurrences of the period.
+  size_t InitialMemory;
+};
+
+constexpr bool GreaterThan(const unsigned char a, const unsigned char b) {
+  return a > b;
+}
+
+constexpr bool LessThan(const unsigned char a, const unsigned char b) {
+  return a < b;
+}
+
+// Computes the maximal suffix and period of 'str' with the given length 'len'.
+// This is generalized over the comparison function, since both '<' and
+// '>' lexicographical comparisons are necessary to compute the critical
+// factorization.
+template <typename ComparisonFunction>
+static inline Factorization compute_maximal_suffix(const unsigned char *str,
+                                                   const size_t len,
+                                                   ComparisonFunction compare) {
+  int maximal_suffix_index = -1;
+  size_t test_index = 0;
+  size_t k = 1;
+  size_t period = 1;
+  while (test_index + k < len) {
+    if (str[maximal_suffix_index + k] == str[test_index + k]) {
+      if (k == period) {
+        test_index += period;
+        k = 1;
+      } else {
+        ++k;
+      }
+    } else if (compare(str[maximal_suffix_index + k], str[test_index + k])) {
+      test_index += k;
+      k = 1;
+      period = test_index - maximal_suffix_index;
+    } else {
+      maximal_suffix_index = test_index;
+      ++test_index;
+      k = period = 1;
+    }
+  }
+  Factorization f = {}; // Value-initialized: InitialMemory starts at 0.
+  f.MaximalSuffixIndex = internal::Max(0, maximal_suffix_index);
+  f.Period = period;
+  return f;
+}
+
+// Computes the critical factorization of 'str' with length 'len'.
+// A critical factorization is a factorization (u, v) of a string X such that
+// the repetition of (u, v) is equal to the period of X. This is computed in
+// 2 * 'len' comparisons. Also determines if the entirety of 'str' is periodic,
+// and uses Factorization.InitialMemory to avoid rescans in this case.
+static inline Factorization critical_factorization(const unsigned char *str,
+                                                   size_t len) {
+  const Factorization LT = compute_maximal_suffix(str, len, LessThan);
+  const Factorization GT = compute_maximal_suffix(str, len, GreaterThan);
+  Factorization f = LT.MaximalSuffixIndex > GT.MaximalSuffixIndex ? LT : GT;
+
+  // Determine if entire string is periodic.
+  if (internal::memcmp(str, str + f.Period, f.MaximalSuffixIndex + 1)) {
+    f.InitialMemory = 0;
+    f.Period =
+        internal::Max(f.MaximalSuffixIndex, len - f.MaximalSuffixIndex - 1) + 1;
+  } else {
+    f.InitialMemory = len - f.Period;
+  }
+  return f;
+}
+
+// Uses the two-way string matching algorithm, in conjunction with a
+// bad character shift table, as first described in Boyer-Moore. The
+// algorithm guarantees linear time complexity and constant space
+// complexity.
+//
+// References:
+// 1. Two-Way:
+// https://en.wikipedia.org/wiki/Two-way_string-matching_algorithm
+// 2. Boyer-Moore:
+// https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string-search_algorithm
+static inline char *two_way_string_match(const char *h, const char *n) {
+  const unsigned char *haystack = reinterpret_cast<const unsigned char *>(h);
+  const unsigned char *needle = reinterpret_cast<const unsigned char *>(n);
+  cpp::Bitset<256> needle_set;
+  size_t shift_table[256]; // size_t: 'needle_len + 1' can exceed 255.
+
+  size_t needle_len = 0;
+  for (; needle[needle_len] && haystack[needle_len]; ++needle_len) {
+    const unsigned char ch = needle[needle_len];
+    needle_set.set(ch);
+    shift_table[ch] = needle_len + 1;
+  }
+  if (needle[needle_len] != '\0') // haystack length < needle length.
+    return nullptr;
+
+  const Factorization f = critical_factorization(needle, needle_len);
+
+  const unsigned char *end_of_haystack = haystack;
+  size_t memory = 0;
+  for (;;) {
+    if (static_cast<size_t>(end_of_haystack - haystack) < needle_len) {
+      const auto *lengthened = reinterpret_cast<const unsigned char *>(
+          internal::memchr(end_of_haystack, '\0', needle_len));
+      if (lengthened != nullptr) {
+        if (static_cast<size_t>(lengthened - haystack) < needle_len)
+          return nullptr;
+        end_of_haystack = lengthened;
+      } else {
+        end_of_haystack += needle_len;
+      }
+    }
+    const size_t last_byte = haystack[needle_len - 1];
+    if (needle_set.test(last_byte)) {
+      const size_t bad_character_shift = needle_len - shift_table[last_byte];
+      if (bad_character_shift > 0) {
+        haystack += internal::Max(bad_character_shift, memory);
+        memory = 0;
+        continue;
+      }
+    } else {
+      haystack += needle_len;
+      memory = 0;
+      continue;
+    }
+    // Compare right half of the factorization.
+    size_t k = internal::Max(f.MaximalSuffixIndex + 1, memory);
+    for (; needle[k] && needle[k] == haystack[k]; ++k)
+      ;
+    if (needle[k] != '\0') {
+      haystack += k - f.MaximalSuffixIndex;
+      memory = 0;
+      continue;
+    }
+    // Compare left half of the factorization.
+    k = f.MaximalSuffixIndex + 1;
+    for (; k > memory && needle[k - 1] == haystack[k - 1]; --k)
+      ;
+    if (k <= memory)
+      return reinterpret_cast<char *>(const_cast<unsigned char *>(haystack));
+
+    haystack += f.Period;
+    memory = f.InitialMemory;
+  }
+}
+
+// A brute force string matching algorithm, that has O(m * n) complexity,
+// where m is the length of the haystack and n is the length of the needle.
+constexpr char *brute_force_string_match(const char *haystack,
+                                         const char *needle) {
   for (size_t i = 0; haystack[i]; ++i) {
-    size_t j;
-    for (j = 0; haystack[i + j] && haystack[i + j] == needle[j]; ++j)
+    size_t j = 0;
+    for (; haystack[i + j] && haystack[i + j] == needle[j]; ++j)
       ;
     if (!needle[j])
       return const_cast<char *>(haystack + i);
@@ -26,4 +182,19 @@
   return nullptr;
 }
 
+// Returns a pointer to the first occurrence of 'needle' in 'haystack', or
+// nullptr if there is none. Haystacks shorter than 'TwoWayMatchingThreshold'
+// bytes use the brute force matcher; longer ones use the two-way algorithm.
+char *LLVM_LIBC_ENTRYPOINT(strstr)(const char *haystack, const char *needle) {
+  if (*needle == '\0') // Empty needle returns the entire haystack.
+    return const_cast<char *>(haystack);
+
+  constexpr size_t TwoWayMatchingThreshold = 64;
+  for (size_t i = 0; i < TwoWayMatchingThreshold; ++i) {
+    if (haystack[i] == '\0')
+      return brute_force_string_match(haystack, needle);
+  }
+  return two_way_string_match(haystack, needle);
+}
+
 } // namespace __llvm_libc
diff --git a/libc/test/src/string/CMakeLists.txt b/libc/test/src/string/CMakeLists.txt
--- a/libc/test/src/string/CMakeLists.txt
+++ b/libc/test/src/string/CMakeLists.txt
@@ -70,6 +70,7 @@
     strstr_test.cpp
   DEPENDS
     libc.src.string.strstr
+    libc.utils.CPP.standalone_cpp
 )
 
 add_libc_unittest(
diff --git a/libc/test/src/string/strstr_test.cpp b/libc/test/src/string/strstr_test.cpp
--- a/libc/test/src/string/strstr_test.cpp
+++ b/libc/test/src/string/strstr_test.cpp
@@ -7,7 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/string/strstr.h"
+#include "utils/CPP/Array.h"
 #include "utils/UnitTest/Test.h"
+#include
 
 TEST(StrStrTest, NeedleNotInHaystack) {
   const char *haystack = "12345";
@@ -33,14 +35,6 @@
   ASSERT_STREQ(__llvm_libc::strstr(haystack, needle), "");
 }
 
-TEST(StrStrTest, HaystackAndNeedleAreSingleCharacters) {
-  const char *haystack = "a";
-  // Same characer returns that character.
-  ASSERT_STREQ(__llvm_libc::strstr(haystack, /*needle=*/"a"), "a");
-  // Different character returns nullptr.
-  ASSERT_STREQ(__llvm_libc::strstr(haystack, /*needle=*/"b"), nullptr);
-}
-
 TEST(StrStrTest, NeedleEqualToHaystack) {
   const char *haystack = "12345";
   const char *needle = "12345";
@@ -85,10 +79,14 @@
 }
 
 TEST(StrStrTest, PartialNeedle) {
-  const char *haystack = "la_ap_lap";
   const char *needle = "lap";
-  // Shouldn't find la or ap.
-  ASSERT_STREQ(__llvm_libc::strstr(haystack, needle), "lap");
+  // Shouldn't find partials in the beginning.
+  ASSERT_STREQ(__llvm_libc::strstr(
+                   "al_la_alal_lala_lllllll_aaaaaaaa_ppppp_la_ap_lap", needle),
+               "lap");
+  // Same if needle is at the beginning.
+  const char *haystack = "lap_al_la_alal_lala_lllllll_aaaaaaaa_ppppp_la_ap_lap";
+  ASSERT_STREQ(__llvm_libc::strstr(haystack, needle), haystack);
 }
 
 TEST(StrStrTest, MisspelledNeedle) {
@@ -98,7 +96,7 @@
 }
 
 TEST(StrStrTest, AnagramNeedle) {
-  const char *haystack = "dgo_ogd_god_odg_gdo_dog";
+  const char *haystack = "???_123_fff_ooo_ddd_dgo_ogd_god_odg_gdo_dog";
   const char *needle = "dog";
   ASSERT_STREQ(__llvm_libc::strstr(haystack, needle), "dog");
 }
@@ -112,3 +110,79 @@
   ASSERT_STREQ(__llvm_libc::strstr(haystack, "tire"), nullptr);
   ASSERT_STREQ(__llvm_libc::strstr(haystack, "timo"), nullptr);
 }
+
+TEST(StrStrTest, NeedleIsSingleByte) {
+  // Different character returns nullptr.
+  ASSERT_STREQ(__llvm_libc::strstr("a", "b"), nullptr);
+  ASSERT_STREQ(__llvm_libc::strstr("w", "v"), nullptr);
+
+  const char *haystack = "abcdefghijklmnopqrstuvwxyz";
+  const __llvm_libc::cpp::Array needles = {
+      "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
+      "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"};
+  for (int i = 0; i < 26; ++i) { // needles[i] begins at haystack[i].
+    ASSERT_STREQ(__llvm_libc::strstr(haystack, needles[i].data()),
+                 haystack + i);
+  }
+}
+
+TEST(StrStrTest, NeedleIsTwoBytes) {
+  const char *haystack = "abcdefghijklmnopqrstuvwxyz";
+  const __llvm_libc::cpp::Array needles = {
+      "ab", "bc", "cd", "de", "ef", "fg", "gh", "hi", "ij",
+      "jk", "kl", "lm", "mn", "no", "op", "pq", "qr", "rs",
+      "st", "tu", "uv", "vw", "wx", "xy", "yz"};
+  for (int i = 0; i < 25; ++i) { // needles[i] begins at haystack[i].
+    ASSERT_STREQ(__llvm_libc::strstr(haystack, needles[i].data()),
+                 haystack + i);
+  }
+}
+
+TEST(StrStrTest, NeedleIsThreeBytes) {
+  const char *haystack = "abcdefghijklmnopqrstuvwxyz";
+  const __llvm_libc::cpp::Array needles = {
+      "abc", "bcd", "cde", "def", "efg", "fgh", "ghi", "hij",
+      "ijk", "jkl", "klm", "lmn", "mno", "nop", "opq", "pqr",
+      "qrs", "rst", "stu", "tuv", "uvw", "vwx", "wxy", "xyz"};
+  for (int i = 0; i < 24; ++i) { // needles[i] begins at haystack[i].
+    ASSERT_STREQ(__llvm_libc::strstr(haystack, needles[i].data()),
+                 haystack + i);
+  }
+}
+
+TEST(StrStrTest, NeedleIsFourBytes) {
+  const char *haystack = "abcdefghijklmnopqrstuvwxyz";
+  const __llvm_libc::cpp::Array needles = {
+      "abcd", "bcde", "cdef", "defg", "efgh", "fghi", "ghij", "hijk",
+      "ijkl", "jklm", "klmn", "lmno", "mnop", "nopq", "opqr", "pqrs",
+      "qrst", "rstu", "stuv", "tuvw", "uvwx", "vwxy", "wxyz"};
+  for (int i = 0; i < 23; ++i) { // needles[i] begins at haystack[i].
+    ASSERT_STREQ(__llvm_libc::strstr(haystack, needles[i].data()),
+                 haystack + i);
+  }
+}
+
+TEST(StrStrTest, NeedleIsFiveBytes) {
+  const char *haystack = "abcdefghijklmnopqrstuvwxyz";
+  const __llvm_libc::cpp::Array needles = {
+      "abcde", "bcdef", "cdefg", "defgh", "efghi", "fghij", "ghijk", "hijkl",
+      "ijklm", "jklmn", "klmno", "lmnop", "mnopq", "nopqr", "opqrs", "pqrst",
+      "qrstu", "rstuv", "stuvw", "tuvwx", "uvwxy", "vwxyz"};
+  for (int i = 0; i < 22; ++i) { // needles[i] begins at haystack[i].
+    ASSERT_STREQ(__llvm_libc::strstr(haystack, needles[i].data()),
+                 haystack + i);
+  }
+}
+
+TEST(StrStrTest, LongNeedle) {
+  const char *haystack = "0123456789" // 10-byte prefix before the needle.
+                         "XXXXXXXXXXXXXXXX"
+                         "XXXXXXXXXXXXXXXX"
+                         "XXXXXXXXXXXXXXXX"
+                         "XXXXXXXXXXXXXXXX";
+  const char *needle = "XXXXXXXXXXXXXXXX" // 64 bytes
+                       "XXXXXXXXXXXXXXXX"
+                       "XXXXXXXXXXXXXXXX"
+                       "XXXXXXXXXXXXXXXX";
+  ASSERT_STREQ(__llvm_libc::strstr(haystack, needle), haystack + 10);
+}