diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -26,6 +26,14 @@
     libc.src.sched.sched_getaffinity
     libc.src.sched.sched_setaffinity
 
+    # search.h entrypoints
+    libc.src.search.hcreate
+    libc.src.search.hcreate_r
+    libc.src.search.hdestroy
+    libc.src.search.hdestroy_r
+    libc.src.search.hsearch
+    libc.src.search.hsearch_r
+
     # string.h entrypoints
     libc.src.string.bcmp
     libc.src.string.bzero
diff --git a/libc/config/linux/x86_64/headers.txt b/libc/config/linux/x86_64/headers.txt
--- a/libc/config/linux/x86_64/headers.txt
+++ b/libc/config/linux/x86_64/headers.txt
@@ -10,6 +10,7 @@
     libc.include.pthread
     libc.include.sched
     libc.include.signal
+    libc.include.search
     libc.include.spawn
     libc.include.stdio
     libc.include.stdlib
diff --git a/libc/include/CMakeLists.txt b/libc/include/CMakeLists.txt
--- a/libc/include/CMakeLists.txt
+++ b/libc/include/CMakeLists.txt
@@ -226,6 +226,18 @@
     .llvm-libc-types.posix_spawn_file_actions_t
 )
 
+add_gen_header(
+  search
+  DEF_FILE search.h.def
+  GEN_HDR search.h
+  DEPENDS
+    .llvm_libc_common_h
+    .llvm-libc-types.action
+    .llvm-libc-types.struct_hsearch_data
+    .llvm-libc-types.entry
+    .llvm-libc-types.size_t
+)
+
 # TODO: Not all platforms will have a include/sys directory. Add the sys
 # directory and the targets for sys/*.h files conditional to the OS requiring
 # them.
diff --git a/libc/include/llvm-libc-types/CMakeLists.txt b/libc/include/llvm-libc-types/CMakeLists.txt
--- a/libc/include/llvm-libc-types/CMakeLists.txt
+++ b/libc/include/llvm-libc-types/CMakeLists.txt
@@ -79,3 +79,6 @@
 add_header(speed_t HDR speed_t.h)
 add_header(tcflag_t HDR tcflag_t.h)
 add_header(struct_termios HDR struct_termios.h DEPENDS .cc_t .speed_t .tcflag_t)
+add_header(action HDR action.h)
+add_header(entry HDR entry.h)
+add_header(struct_hsearch_data HDR struct_hsearch_data.h)
diff --git a/libc/include/llvm-libc-types/action.h b/libc/include/llvm-libc-types/action.h
new file mode 100644
--- /dev/null
+++ b/libc/include/llvm-libc-types/action.h
@@ -0,0 +1,14 @@
+//===-- Definition of type enum ACTION ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef __LLVM_LIBC_TYPES_ENUM_ACTION_H__
+#define __LLVM_LIBC_TYPES_ENUM_ACTION_H__
+
+typedef enum { FIND, ENTER } ACTION;
+
+#endif // __LLVM_LIBC_TYPES_ENUM_ACTION_H__
diff --git a/libc/include/llvm-libc-types/entry.h b/libc/include/llvm-libc-types/entry.h
new file mode 100644
--- /dev/null
+++ b/libc/include/llvm-libc-types/entry.h
@@ -0,0 +1,17 @@
+//===-- Definition of type struct ENTRY -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef __LLVM_LIBC_TYPES_STRUCT_ENTRY_H__
+#define __LLVM_LIBC_TYPES_STRUCT_ENTRY_H__
+
+typedef struct {
+  char *key;
+  void *data;
+} ENTRY;
+
+#endif // __LLVM_LIBC_TYPES_STRUCT_ENTRY_H__
diff --git a/libc/include/llvm-libc-types/struct_hsearch_data.h b/libc/include/llvm-libc-types/struct_hsearch_data.h
new file mode 100644
--- /dev/null
+++ b/libc/include/llvm-libc-types/struct_hsearch_data.h
@@ -0,0 +1,17 @@
+//===-- Definition of type struct hsearch_data ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef __LLVM_LIBC_TYPES_STRUCT_HSEARCH_DATA_H__
+#define __LLVM_LIBC_TYPES_STRUCT_HSEARCH_DATA_H__
+
+struct hsearch_data {
+  void *__opaque;
+  unsigned int __unused[2]; // NOTE(review): may clash with a BSD __unused macro
+};
+
+#endif // __LLVM_LIBC_TYPES_STRUCT_HSEARCH_DATA_H__
diff --git a/libc/include/search.h.def b/libc/include/search.h.def
new file mode 100644
--- /dev/null
+++ b/libc/include/search.h.def
@@ -0,0 +1,20 @@
+//===-- POSIX header search.h ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SEARCH_H
+#define LLVM_LIBC_SEARCH_H
+
+#include <__llvm-libc-common.h>
+#include 
+#include 
+#include 
+#include 
+
+%%public_api()
+
+#endif // LLVM_LIBC_SEARCH_H
diff --git a/libc/spec/posix.td b/libc/spec/posix.td
--- a/libc/spec/posix.td
+++ b/libc/spec/posix.td
@@ -967,6 +967,71 @@
       ]
   >;
 
+  NamedType StructHsearchDataName = NamedType<"struct hsearch_data">;
+  PtrType StructHsearchDataNamePtr = PtrType;
+  NamedType ActionType = NamedType<"ACTION">;
+  NamedType EntryType = NamedType<"ENTRY">;
+  PtrType EntryTypePtr = PtrType;
+  PtrType EntryTypePtrPtr = PtrType;
+
+  HeaderSpec Search = HeaderSpec<
+    "search.h",
+    [], // Macros
+    [
+      StructHsearchDataName,
+      ActionType,
+      EntryType
+    ], // Types
+    [], // Enumerations
+    [
+        FunctionSpec<
+            "hcreate",
+            RetValSpec,
+            [
+                ArgSpec
+            ]
+        >,
+        FunctionSpec<
+            "hcreate_r",
+            RetValSpec,
+            [
+                ArgSpec,
+                ArgSpec
+            ]
+        >,
+        FunctionSpec<
+            "hdestroy",
+            RetValSpec,
+            [] // Args
+        >,
+        FunctionSpec<
+            "hdestroy_r",
+            RetValSpec,
+            [
+                ArgSpec
+            ]
+        >,
+        FunctionSpec<
+            "hsearch",
+            RetValSpec,
+            [
+                ArgSpec,
+                ArgSpec
+            ]
+        >,
+        FunctionSpec<
+            "hsearch_r",
+            RetValSpec,
+            [
+                ArgSpec,
+                ArgSpec,
+                ArgSpec,
+                ArgSpec
+            ]
+        >,
+    ]
+  >;
+
   HeaderSpec StdIO = HeaderSpec<
       "stdio.h",
       [], // Macros
@@ -1192,6 +1257,7 @@
       PThread,
       Signal,
       Spawn,
+      Search,
       StdIO,
       StdLib,
       SysIOctl,
diff --git a/libc/src/CMakeLists.txt b/libc/src/CMakeLists.txt
--- a/libc/src/CMakeLists.txt
+++ b/libc/src/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_subdirectory(string)
 add_subdirectory(stdlib)
 add_subdirectory(stdio)
+add_subdirectory(search)
 
 if(${LIBC_TARGET_OS} STREQUAL "linux")
   add_subdirectory(dirent)
diff --git a/libc/src/__support/CMakeLists.txt b/libc/src/__support/CMakeLists.txt
--- a/libc/src/__support/CMakeLists.txt
+++ b/libc/src/__support/CMakeLists.txt
@@ -129,6 +129,29 @@
     .uint
 )
 
+add_header_library(
+  swisstable
+  HDRS
+    swisstable.h
+  DEPENDS
+    .common
+    .builtin_wrappers
+    libc.src.__support.CPP.limits
+    libc.src.string.memcpy
+    libc.src.string.memset
+    libc.include.stdlib
+)
+
+add_header_library(
+  wyhash
+  HDRS
+    wyhash.h
+  DEPENDS
+    .common
+    .uint128
+    libc.src.string.memcpy
+)
+
 add_subdirectory(FPUtil)
 add_subdirectory(OSUtil)
 add_subdirectory(StringUtil)
diff --git a/libc/src/__support/builtin_wrappers.h b/libc/src/__support/builtin_wrappers.h
--- a/libc/src/__support/builtin_wrappers.h
+++ b/libc/src/__support/builtin_wrappers.h
@@ -26,6 +26,14 @@
 }
 
 template static inline int clz(T val);
+template <> inline int clz(unsigned short val) {
+#if __has_builtin(__builtin_clzs)
+  return __builtin_clzs(val);
+#else // Count in unsigned int, then drop the leading zeros added by widening.
+  return __builtin_clz(static_cast(val)) -
+         8 * (sizeof(unsigned int) - sizeof(unsigned short));
+#endif
+}
 template <> inline int clz(unsigned int val) {
   return __builtin_clz(val);
 }
@@ -37,6 +45,13 @@
 }
 
 template static inline int ctz(T val);
+template <> inline int ctz(unsigned short val) {
+#if __has_builtin(__builtin_ctzs)
+  return __builtin_ctzs(val);
+#else // Zero-extension keeps the trailing zeros, so the 32-bit count is exact.
+  return __builtin_ctz(static_cast(val));
+#endif
+}
 template <> inline int ctz(unsigned int val) {
   return __builtin_ctz(val);
 }
diff --git a/libc/src/__support/swisstable.h b/libc/src/__support/swisstable.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/swisstable.h
@@ -0,0 +1,619 @@
+//===-- SwissTable ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/CPP/array.h"
+#include "src/__support/CPP/limits.h"
+#include "src/__support/builtin_wrappers.h"
+#include "src/__support/common.h"
+#include "src/__support/swisstable/dispatch.h"
+#include "src/string/memcpy.h"
+#include "src/string/memory_utils/memcpy_implementations.h"
+#include "src/string/memset.h"
+#include 
+
+namespace __llvm_libc::cpp::swisstable {
+// A control byte is FULL when its top bit is clear; both special values
+// (EMPTY = 0xFF and DELETED = 0x80) have the top bit set.
+static inline bool is_full(CtrlWord ctrl) { return (ctrl & 0x80) == 0; }
+// Among special (non-FULL) control bytes only EMPTY has the lowest bit set,
+// so this tells EMPTY apart from DELETED.
+static inline bool special_is_empty(CtrlWord ctrl) {
+  return (ctrl & 0x01) != 0;
+}
+// h1: the part of the hash that selects the initial probe position.
+static inline size_t h1(HashWord hash) { return static_cast(hash); }
+// h2: the 7-bit tag stored in the control byte of a FULL bucket.
+static inline CtrlWord h2(HashWord hash) {
+  static constexpr size_t HASH_LENGTH =
+      sizeof(size_t) < sizeof(HashWord) ? sizeof(size_t) : sizeof(HashWord);
+  // We want the top 7 bits of the h1.
+  return (hash >> (HASH_LENGTH * 8 - 7)) & 0x7f;
+}
+
+// Triangular probing over groups: the step grows by one group width each
+// time, and positions wrap around through the power-of-two bucket mask.
+struct ProbeSequence {
+  size_t position;
+  size_t stride;
+
+  void move_next(size_t bucket_mask) {
+    stride += sizeof(Group);
+    position += stride;
+    position &= bucket_mask;
+  }
+};
+
+// Rounds val up to a power of two. val must be at least 2, otherwise
+// unsafe_clz would be called with 0.
+static inline size_t next_power_of_two(size_t val) {
+  size_t idx = __llvm_libc::unsafe_clz(val - 1);
+  return 1ull << ((8ull * sizeof(size_t)) - idx);
+}
+
+// Smallest bucket count whose usable capacity (bucket_mask_to_capacity) is
+// at least cap. Tables with 8 or more buckets may only be filled to 7/8, so
+// they need cap * 8 / 7 buckets rounded up to a power of two (the rounding
+// always leaves enough room). Callers must make sure cap * 8 cannot overflow.
+static inline size_t capacity_to_buckets(size_t cap) {
+  if (cap < 8) {
+    return (cap < 4) ? 4 : 8;
+  }
+  return next_power_of_two(cap * 8 / 7);
+}
+
+// Maximum number of items for a table of bucket_mask + 1 buckets: tables of
+// at most 8 buckets keep one bucket free, larger ones may fill 7/8 of them.
+static inline size_t bucket_mask_to_capacity(size_t bucket_mask) {
+  if (bucket_mask < 8) {
+    return bucket_mask;
+  } else {
+    return (bucket_mask + 1) / 8 * 7;
+  }
+}
+
+// A Group-aligned array of SIZE copies of one control value that can be
+// built in a constant expression.
+template struct ConstantStorage {
+  alignas(Group) T data[SIZE];
+  // `data{}` initializes the member up front, as C++17 requires of every
+  // member in a constexpr constructor; the loop then fills in the value.
+  constexpr ConstantStorage(CtrlWord val) : data{} {
+    for (size_t i = 0; i < SIZE; ++i) {
+      data[i] = val;
+    }
+  }
+};
+
+// Control bytes shared by every zero-capacity table. This is a constant and
+// must never be written to.
+constexpr static inline ConstantStorage
+    CONST_EMPTY_GROUP = {EMPTY};
+
+// The heap memory layout for N buckets of size S:
+//
+//             =========================================
+// Fields:     | buckets |  ctrl bytes |      group    |
+//             -----------------------------------------
+// Size:       |   N*S   |      N      | sizeof(Group) |
+//             =========================================
+//                       ^
+//                       |
+//               Store this position in RawTable.
+//
+// The trailing group part is to make sure we can always load
+// a whole group of control bytes.
+
+// A memory size limited to the max of ssize_t. A negative value marks an
+// invalid (out-of-range or overflowed) size; every sum or product computed
+// from an invalid operand stays invalid.
+class SafeMemSize {
+private:
+  static constexpr size_t MAX_MEM_SIZE =
+      static_cast(numeric_limits::max());
+  ssize_t value;
+  explicit SafeMemSize(ssize_t value) : value(value) {}
+
+public:
+  explicit SafeMemSize(size_t value)
+      : value(value <= MAX_MEM_SIZE ? static_cast(value) : -1) {}
+  operator size_t() const { return static_cast(value); }
+  bool valid() const { return value >= 0; }
+  SafeMemSize operator+(const SafeMemSize &other) {
+    ssize_t result;
+    // Poison the sum if an operand is already invalid or if the signed
+    // addition overflows (which would otherwise be undefined behaviour).
+    if (unlikely((value | other.value) < 0) ||
+        unlikely(__builtin_add_overflow(value, other.value, &result)))
+      result = -1;
+    return SafeMemSize{result};
+  }
+  SafeMemSize operator*(const SafeMemSize &other) {
+    ssize_t result;
+    // An invalid operand must not be "healed" by the multiplication
+    // (e.g. -1 * -1 == 1), so check it before multiplying.
+    if (unlikely((value | other.value) < 0) ||
+        unlikely(__builtin_mul_overflow(value, other.value, &result)))
+      result = -1;
+    return SafeMemSize{result};
+  }
+};
+
+template struct TableLayout {
+  // Pick the largest alignment between T and Group.
+  static constexpr inline size_t ALIGNMENT = alignof(T) > alignof(Group)
+                                                 ? alignof(T)
+                                                 : alignof(Group);
+  size_t offset;
+  size_t size;
+
+  TableLayout(size_t offset, size_t size) : offset(offset), size(size) {}
+
+  // We want to find an aligned boundary, put buckets to its left and
+  // put ctrl words to its right. So we round buckets * sizeof(T) up to
+  // ALIGNMENT by clearing the low bits of
+  // (buckets * sizeof(T) + alignment - 1).
+  // `valid` is set to false if any of the sizes does not fit in ssize_t.
+  static TableLayout checked(size_t buckets, bool &valid) {
+    valid = true;
+    SafeMemSize padded = SafeMemSize{buckets} * SafeMemSize{sizeof(T)} +
+                         SafeMemSize{ALIGNMENT - 1};
+    valid = valid && padded.valid();
+    size_t offset = static_cast(padded) & ~(ALIGNMENT - 1);
+
+    SafeMemSize safe_size =
+        SafeMemSize{offset} + SafeMemSize{buckets} + SafeMemSize{sizeof(Group)};
+    valid = valid && safe_size.valid();
+    size_t size = static_cast(safe_size);
+
+    return {offset, size};
+  }
+
+  // Same computation without overflow checks; only for bucket counts that
+  // already passed checked() (e.g. those of an allocated table).
+  static TableLayout unchecked(size_t buckets) {
+    size_t offset = (buckets * sizeof(T) + ALIGNMENT - 1) & ~(ALIGNMENT - 1);
+    size_t size = offset + buckets + sizeof(Group);
+    return {offset, size};
+  }
+};
+
+// We will not consider zero-sized type
+// A Bucket points one past its element: buckets are stored in reverse order
+// to the left of `base`, so bucket `index` is the element at base[-index-1].
+template struct Bucket {
+  // NOTE: sizeof(T) is never 0 in C++, so this check actually rejects
+  // 1-byte element types as well.
+  static_assert(sizeof(T) > 1, "zero sized type not allowed");
+
+  T *ptr;
+
+  Bucket(T *base, size_t index) : ptr(base - index) {}
+
+  size_t base_index(T *base) const { return base - ptr; }
+
+  T &operator*() { return ptr[-1]; }
+
+  // Moves `offset` buckets further away from the base.
+  Bucket operator+(size_t offset) { return {ptr, offset}; }
+
+  bool operator==(const Bucket &other) const { return ptr == other.ptr; }
+
+  // A null bucket (ptr == nullptr) signals "not found" / "no room".
+  operator bool() const { return ptr != nullptr; }
+};
+
+// T should be trivial type.
+// Open-addressing hash table in the SwissTable style. Elements of type T are
+// stored by value in a bucket array laid out immediately before the control
+// bytes. Each control byte is EMPTY, DELETED or, for a FULL bucket, the
+// 7-bit h2 tag of the element's hash. ENABLE_DELETION switches probing to the
+// rules needed once DELETED markers may exist and allows in-place rehashing;
+// ENABLE_RESIZE allows the table to grow when it runs out of room.
+template class RawTable {
+  using Layout = TableLayout;
+  // Number of buckets minus one; the bucket count is a power of two.
+  size_t bucket_mask;
+  // Pointer to the control words (the buckets are stored right before it).
+  CtrlWord *ctrl;
+  // Number of items before growth is required
+  size_t growth_left;
+  // Number of items
+  size_t items;
+
+public:
+  // A table without any storage (ctrl == nullptr). Used to report size
+  // overflow and allocation failure; is_valid() is false for it.
+  static RawTable create_invalid() {
+    RawTable table{};
+    table.bucket_mask = 0;
+    table.ctrl = nullptr;
+    table.growth_left = 0;
+    table.items = 0;
+    return table;
+  }
+
+  // Allocates room for `buckets` buckets (a power of two) plus their control
+  // bytes, leaving the control bytes uninitialized. Returns an invalid table
+  // (ctrl == nullptr) if the size overflows or the allocation fails.
+  static RawTable uninitialized(size_t buckets) {
+    // Start from the invalid table, not the shared empty singleton: callers
+    // detect failure through `ctrl != nullptr`. Returning the singleton would
+    // look like success, and with_capacity / resize would then write into
+    // the read-only CONST_EMPTY_GROUP.
+    RawTable table = RawTable::create_invalid();
+    bool valid;
+    Layout layout = Layout::checked(buckets, valid);
+
+    if (unlikely(!valid))
+      return table;
+
+    // C11 requires the aligned_alloc size to be a multiple of the alignment.
+    // A valid layout.size fits in ssize_t, so rounding up cannot overflow.
+    size_t alloc_size =
+        (layout.size + Layout::ALIGNMENT - 1) & ~(Layout::ALIGNMENT - 1);
+    CtrlWord *address =
+        static_cast(aligned_alloc(Layout::ALIGNMENT, alloc_size));
+    if (unlikely(address == nullptr))
+      return table;
+
+    table.bucket_mask = buckets - 1;
+    table.ctrl = address + layout.offset;
+    table.growth_left = bucket_mask_to_capacity(table.bucket_mask);
+    return table;
+  }
+
+  // The zero-capacity table. Its ctrl points at the shared, read-only
+  // CONST_EMPTY_GROUP, so probes only see EMPTY bytes; growth_left is 0, so
+  // the first insertion has to go through reserve_rehash.
+  static RawTable create() {
+    RawTable table{};
+    table.ctrl = const_cast(&CONST_EMPTY_GROUP.data[0]);
+    table.bucket_mask = 0;
+    table.growth_left = 0;
+    table.items = 0;
+    return table;
+  }
+
+  // Creates a table able to hold at least `capacity` items, with every
+  // control byte EMPTY. Returns the empty singleton for capacity 0 and an
+  // invalid table (is_valid() == false) on overflow or allocation failure.
+  static RawTable with_capacity(size_t capacity) {
+    if (capacity == 0)
+      return create();
+
+    // additional check for capacity overflow
+    auto safe_cap = SafeMemSize{capacity};
+    safe_cap = safe_cap * SafeMemSize{size_t{8}};
+    if (!safe_cap.valid()) {
+      return create_invalid();
+    }
+
+    size_t buckets = capacity_to_buckets(capacity);
+    RawTable table = uninitialized(buckets);
+    if (likely(table.ctrl != nullptr))
+      memset(reinterpret_cast(table.ctrl), EMPTY,
+             table.num_ctrl_words());
+    return table;
+  }
+
+  // Frees the heap storage, if any, and leaves the table invalid. Does not
+  // free the empty singleton's storage, and is a no-op for an invalid table,
+  // so calling it twice is harmless.
+  void release() {
+    // An invalid table owns no memory; `ctrl - layout.offset` would be a
+    // bogus pointer to hand to free().
+    if (is_valid() && !is_empty_singleton()) {
+      Layout layout = Layout::unchecked(num_buckets());
+      free(ctrl - layout.offset);
+    }
+    ctrl = nullptr;
+  }
+
+  bool is_valid() const { return ctrl != nullptr; }
+
+private:
+  size_t num_buckets() const { return bucket_mask + 1; }
+  // One control byte per bucket plus a trailing group that mirrors the first
+  // sizeof(Group) bytes (see set_ctrl).
+  size_t num_ctrl_words() const { return num_buckets() + sizeof(Group); }
+
+  // Buckets grow downwards from ctrl: bucket i is data_end()[-i - 1].
+  T *data_end() const { return reinterpret_cast(ctrl); }
+  T *data_begin() const { return reinterpret_cast(ctrl) - num_buckets(); }
+
+  Bucket bucket_at(size_t index) const { return {data_end(), index}; }
+  size_t bucket_index(const Bucket &bucket) const {
+    return bucket.base_index(data_end());
+  }
+
+  // Sets a control byte, and possibly also the replicated control byte at
+  // the end of the array.
+  void set_ctrl(size_t index, CtrlWord value) {
+    // Replicate the first sizeof(Group) control bytes at the end of
+    // the array without using a branch:
+    // - If index >= sizeof(Group) then index == index2.
+    // - Otherwise index2 == self.bucket_mask + 1 + index.
+    //
+    // The very last replicated control byte is never actually read because
+    // we mask the initial index for unaligned loads, but we write it
+    // anyways because it makes the set_ctrl implementation simpler.
+    //
+    // If there are fewer buckets than sizeof(Group) then this code will
+    // replicate the buckets at the end of the trailing group. For example
+    // with 2 buckets and a group size of 4, the control bytes will look
+    // like this:
+    //     =============================================
+    //     |       Real        |        Replicated     |
+    //     ---------------------------------------------
+    //     | [A] | [B] | [EMPTY] | [EMPTY] | [A] | [B] |
+    //     =============================================
+
+    size_t index2 = ((index - sizeof(Group)) & bucket_mask) + sizeof(Group);
+    ctrl[index] = value;
+    ctrl[index2] = value;
+  }
+
+  // Probing starts at h1(hash) reduced modulo the bucket count.
+  ProbeSequence probe_sequence(HashWord hash) const {
+    return {h1(hash) & bucket_mask, 0};
+  }
+
+  size_t proper_insertion_slot(size_t index) const {
+    // In tables smaller than the group width, trailing control
+    // bytes outside the range of the table are filled with
+    // EMPTY entries. These will unfortunately trigger a
+    // match, but once masked may point to a full bucket that
+    // is already occupied. We detect this situation here and
+    // perform a second scan starting at the beginning of the
+    // table. This second scan is guaranteed to find an empty
+    // slot (due to the load factor) before hitting the trailing
+    // control bytes (containing EMPTY).
+    if (unlikely(is_full(ctrl[index]))) {
+      return Group::aligned_load(ctrl)
+          .mask_empty_or_deleted()
+          .lowest_set_bit_nonzero();
+    }
+    return index;
+  }
+
+  // Index of the first EMPTY or DELETED bucket on the probe sequence of
+  // `hash`. The load factor guarantees that such a bucket exists.
+  size_t find_insert_slot(HashWord hash) const {
+    ProbeSequence seq = probe_sequence(hash);
+    while (true) {
+      Group group = Group::load(&ctrl[seq.position]);
+      BitMask empty_slot = group.mask_empty_or_deleted();
+      if (empty_slot.any_bit_set()) {
+        size_t result =
+            (seq.position + empty_slot.lowest_set_bit_nonzero()) & bucket_mask;
+        return proper_insertion_slot(result);
+      }
+      seq.move_next(bucket_mask);
+    }
+  }
+
+  // Returns the bucket holding an element equal to `value` (per `eq`),
+  // inserting a copy of `value` if there is none. Returns a null bucket if
+  // the insertion needed more room and reserve_rehash failed.
+  template 
+  Bucket find_or_insert_with_hash(HashWord hash, const T &value, Equal eq,
+                                  Hasher hasher) {
+    CtrlWord h2_hash = h2(hash);
+    ProbeSequence seq = probe_sequence(hash);
+    while (true) {
+      Group group = Group::load(&ctrl[seq.position]);
+      for (size_t bit : group.match_byte(h2_hash)) {
+        size_t index = (seq.position + bit) & bucket_mask;
+        auto bucket = bucket_at(index);
+        if (likely(eq(*bucket, value)))
+          return bucket;
+      }
+
+      if constexpr (ENABLE_DELETION) {
+        BitMask empty_slot = group.mask_empty();
+        if (likely(empty_slot.any_bit_set())) {
+          size_t index = find_insert_slot(hash);
+          return insert_at(index, hash, value, hasher);
+        }
+      } else {
+        BitMask empty_slot = group.mask_empty_or_deleted();
+        if (likely(empty_slot.any_bit_set())) {
+          size_t index = (seq.position + empty_slot.lowest_set_bit_nonzero()) &
+                         bucket_mask;
+          index = proper_insertion_slot(index);
+          return insert_at(index, hash, value, hasher);
+        }
+      }
+
+      seq.move_next(bucket_mask);
+    }
+  }
+
+  // Returns the bucket holding an element equal to `value`, or a null
+  // bucket. Probing stops at the first group that contains an EMPTY byte.
+  template 
+  Bucket find_with_hash(HashWord hash, const T &value, Equal eq) const {
+    CtrlWord h2_hash = h2(hash);
+    ProbeSequence seq = probe_sequence(hash);
+    while (true) {
+      Group group = Group::load(&ctrl[seq.position]);
+      for (size_t bit : group.match_byte(h2_hash)) {
+        size_t index = (seq.position + bit) & bucket_mask;
+        auto bucket = bucket_at(index);
+        if (likely(eq(*bucket, value)))
+          return bucket;
+      }
+
+      if constexpr (ENABLE_DELETION) {
+        if (likely(group.mask_empty().any_bit_set()))
+          return {nullptr, 0};
+      } else {
+        // Only EMPTY will appear; no need to distinguish EMPTY and DELETED.
+        if (likely(group.mask_empty_or_deleted().any_bit_set()))
+          return {nullptr, 0};
+      }
+
+      seq.move_next(bucket_mask);
+    }
+  }
+
+  // Marks bucket `index` FULL with the h2 tag of `hash`.
+  void set_ctrl_h2(size_t index, HashWord hash) { set_ctrl(index, h2(hash)); }
+
+  // Like set_ctrl_h2, but returns the control byte it overwrote.
+  CtrlWord replace_ctrl_h2(size_t index, HashWord hash) {
+    CtrlWord prev = ctrl[index];
+    set_ctrl_h2(index, hash);
+    return prev;
+  }
+
+  bool is_bucket_full(size_t index) const { return is_full(ctrl[index]); }
+
+  struct Slot {
+    size_t index;
+    CtrlWord prev_ctrl;
+  };
+
+  // Claims the insertion slot for `hash` by writing its tag. Counters are
+  // left untouched (see resize).
+  Slot prepare_insert_slot(HashWord hash) {
+    size_t index = find_insert_slot(hash);
+    CtrlWord prev_ctrl = ctrl[index];
+    set_ctrl_h2(index, hash);
+    return {index, prev_ctrl};
+  }
+
+  void prepare_rehash_inplace() {
+    // convert full to deleted, deleted to empty s.t. we can use
+    // deleted as an indicator for rehash
+    for (size_t i = 0; i < num_buckets(); i += sizeof(Group)) {
+      Group group = Group::aligned_load(&ctrl[i]);
+      Group converted = group.convert_special_to_empty_and_full_to_deleted();
+      converted.aligned_store(&ctrl[i]);
+    }
+
+    // Refresh the mirrored trailing bytes from the converted ones; handle
+    // the case when the table is smaller than a group separately.
+    if (num_buckets() < sizeof(Group)) {
+      memcpy(&ctrl[sizeof(Group)], &ctrl[0], num_buckets());
+    } else {
+      inline_memcpy(reinterpret_cast(&ctrl[num_buckets()]),
+                    reinterpret_cast(&ctrl[0]), sizeof(Group));
+    }
+  }
+
+  // Bookkeeping for an insertion at `index`: only filling an EMPTY slot uses
+  // up growth budget (reusing a DELETED one does not).
+  void record_item_insert_at(size_t index, CtrlWord prev_ctrl, HashWord hash) {
+    growth_left -= special_is_empty(prev_ctrl) ? 1 : 0;
+    set_ctrl_h2(index, hash);
+    items++;
+  }
+
+  // Whether `index` and `new_index` fall into the same group-sized window of
+  // the probe sequence of `hash`, measured from its start position.
+  bool is_in_same_group(size_t index, size_t new_index, HashWord hash) const {
+    size_t probe_position = probe_sequence(hash).position;
+    // The parentheses matter: `/` binds tighter than `&`, and the intent is
+    // to divide the wrapped distance by the group width.
+    size_t position[2] = {
+        ((index - probe_position) & bucket_mask) / sizeof(Group),
+        ((new_index - probe_position) & bucket_mask) / sizeof(Group),
+    };
+    return position[0] == position[1];
+  }
+
+  bool is_empty_singleton() const { return ctrl == CONST_EMPTY_GROUP.data; }
+
+  // Allocates a table for `capacity` items whose counters already account
+  // for the items of this table. ctrl == nullptr on failure.
+  RawTable prepare_resize(size_t capacity) {
+    RawTable new_table = RawTable::with_capacity(capacity);
+    if (likely(new_table.ctrl != nullptr)) {
+      new_table.growth_left -= items;
+      new_table.items += items;
+    }
+    return new_table;
+  }
+
+  // Rehashes every element inside the current allocation, turning all
+  // DELETED markers back into EMPTY ones (used when deletion is enabled).
+  template void rehash_in_place(Hasher hasher) {
+    prepare_rehash_inplace();
+
+    for (size_t idx = 0; idx < num_buckets(); ++idx) {
+      if (ctrl[idx] != DELETED) {
+        continue;
+      }
+
+      Bucket bucket = bucket_at(idx);
+
+      while (true) {
+        HashWord hash = hasher(*bucket_at(idx));
+        size_t new_idx = find_insert_slot(hash);
+        // Probing works by scanning through all of the control
+        // bytes in groups, which may not be aligned to the group
+        // size. If both the new and old position fall within the
+        // same unaligned group, then there is no benefit in moving
+        // it and we can just continue to the next item.
+        if (likely(is_in_same_group(idx, new_idx, hash))) {
+          set_ctrl_h2(idx, hash);
+          break; // continue outer loop
+        }
+
+        Bucket new_bucket = bucket_at(new_idx);
+        CtrlWord prev_ctrl = replace_ctrl_h2(new_idx, hash);
+
+        if (prev_ctrl == EMPTY) {
+          set_ctrl(idx, EMPTY);
+          *new_bucket = *bucket;
+          break; // continue outer loop
+        }
+
+        // The target still holds a not-yet-rehashed element: swap it in and
+        // keep rehashing from the same index.
+        T temp;
+        temp = *new_bucket;
+        *new_bucket = *bucket;
+        *bucket = temp;
+      }
+    }
+
+    growth_left = bucket_mask_to_capacity(bucket_mask) - items;
+  }
+
+  // Moves every element into a freshly allocated table for `new_capacity`
+  // items. Returns false, leaving this table untouched, if allocation fails.
+  template bool resize(size_t new_capacity, Hasher hasher) {
+    RawTable new_table = prepare_resize(new_capacity);
+    if (new_table.ctrl == nullptr)
+      return false;
+
+    for (size_t i = 0; i < num_buckets(); ++i) {
+      if (!is_bucket_full(i))
+        continue;
+
+      Bucket bucket = bucket_at(i);
+      HashWord hash = hasher(*bucket);
+
+      // We can use a simpler version of insert() here since:
+      // - there are no DELETED entries.
+      // - we know there is enough space in the table.
+      // - all elements are unique.
+      Slot slot = new_table.prepare_insert_slot(hash);
+      Bucket new_bucket = new_table.bucket_at(slot.index);
+
+      *new_bucket = *bucket;
+    }
+
+    release();
+    *this = new_table;
+
+    return true;
+  }
+
+  // Makes room for `additional` more items, either by rehashing in place
+  // (deletion enabled and the result at most half full) or by growing
+  // (resize enabled). Returns false if neither is possible or growing fails.
+  template 
+  bool reserve_rehash(size_t additional, Hasher hasher) {
+    SafeMemSize checked_new_items =
+        SafeMemSize{additional} + SafeMemSize{items};
+    if (!checked_new_items.valid())
+      return false;
+
+    size_t new_items = static_cast(checked_new_items);
+    size_t full_capacity = bucket_mask_to_capacity(bucket_mask);
+
+    if constexpr (ENABLE_DELETION) {
+      if (new_items <= full_capacity / 2) {
+        rehash_in_place(hasher);
+        return true;
+      }
+    }
+
+    if constexpr (ENABLE_RESIZE) {
+      size_t new_capacity =
+          new_items > (full_capacity + 1) ? new_items : (full_capacity + 1);
+      return resize(new_capacity, hasher);
+    }
+
+    return false;
+  }
+
+  // Stores `element` in slot `index` (found by find_insert_slot for `hash`),
+  // first making room when the growth budget is spent. Returns the new
+  // bucket, or a null bucket if no room could be made.
+  template 
+  Bucket insert_at(size_t index, HashWord hash, const T &element,
+                   Hasher hasher) {
+    CtrlWord prev_ctrl = ctrl[index];
+
+    // If we reach full load factor:
+    //
+    // - When deletion is allowed and the found slot is deleted, then it is
+    //   okay to insert.
+    // - Otherwise, it means it is an empty slot and such insertion will
+    //   invalidate the load factor constraints. Thus, we need to rehash
+    //   the table. There are several cases:
+    //
+    //   - If deletion is enabled, we may be able to rehash the table in
+    //     place if the total item is less than half of the table.
+    //   - If resizing is enabled, we may be able to grow the table.
+    //   - If neither deletion nor resizing is enabled, we will fail
+    //     immediately.
+    if (unlikely(growth_left == 0 &&
+                 (!ENABLE_DELETION || special_is_empty(prev_ctrl)))) {
+      if (!reserve_rehash(1, hasher))
+        return {nullptr, 0};
+      // The table was rebuilt, so both the slot and the control byte it
+      // held must be looked up again.
+      index = find_insert_slot(hash);
+      prev_ctrl = ctrl[index];
+    }
+
+    record_item_insert_at(index, prev_ctrl, hash);
+
+    Bucket bucket = bucket_at(index);
+    *bucket = element;
+
+    return bucket;
+  }
+
+public:
+  // Inserts `element` without checking for an equal element already present.
+  template Bucket insert(const T &element, Hasher hasher) {
+    HashWord hash = hasher(element);
+    size_t index = find_insert_slot(hash);
+    return insert_at(index, hash, element, hasher);
+  }
+
+  // Returns the bucket of an element equal to `element`, or a null bucket.
+  template 
+  Bucket find(const T &element, Hasher hasher, Equal equal) const {
+    HashWord hash = hasher(element);
+    return find_with_hash(hash, element, equal);
+  }
+
+  // Returns the bucket of an element equal to `element`, inserting it if
+  // absent; a null bucket means the insertion could not make room.
+  template 
+  Bucket find_or_insert(const T &element, Hasher hasher, Equal equal) {
+    HashWord hash = hasher(element);
+    return find_or_insert_with_hash(hash, element, equal, hasher);
+  }
+
+  // Removes the element in `item`. The slot becomes EMPTY (and returns to
+  // the growth budget) unless the run of non-EMPTY control bytes around it
+  // is at least a group wide; a lookup may then have probed past it, so it
+  // becomes DELETED instead.
+  void erase(Bucket item) {
+    size_t index = bucket_index(item);
+    size_t index_before = (index - sizeof(Group)) & bucket_mask;
+    BitMask empty_before = Group::load(&ctrl[index_before]).mask_empty();
+    BitMask empty_after = Group::load(&ctrl[index]).mask_empty();
+    CtrlWord new_ctrl;
+    if (empty_before.leading_zeros() + empty_after.trailing_zeros() >=
+        sizeof(Group)) {
+      new_ctrl = DELETED;
+    } else {
+      growth_left++;
+      new_ctrl = EMPTY;
+    }
+    set_ctrl(index, new_ctrl);
+    items--;
+  }
+};
+
+} // namespace __llvm_libc::cpp::swisstable
diff --git a/libc/src/__support/swisstable/asimd.h b/libc/src/__support/swisstable/asimd.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/swisstable/asimd.h
@@ -0,0 +1,93 @@
+//===-- SwissTable ASIMD Specialization -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_SWISSTABLE_ASIMD_H
+#define LLVM_LIBC_SRC_SUPPORT_SWISSTABLE_ASIMD_H
+
+#include "src/__support/swisstable/common.h"
+#include 
+
+namespace __llvm_libc::cpp::swisstable {
+
+// According to abseil-cpp, ARM's 16-byte ASIMD operations may
+// introduce too much latency, so a group here is a single 8-lane vector.
+// Reference:
+// https://github.com/abseil/abseil-cpp/commit/6481443560a92d0a3a55a31807de0cd712cd4f88
+// With ASIMD, some bitmasks are not iterable. This is because we
+// do not want to clear the lower bits in each stride with extra
+// `AND` operation.
+using BitMask = BitMaskAdaptor;
+using IteratableBitMask = IteratableBitMaskAdaptor;
+
+struct Group {
+  int8x8_t data;
+
+  // Load a group of control words from an arbitrary address.
+  static Group load(const void *__restrict addr) {
+    return {vld1_s8(static_cast(addr))};
+  }
+
+  // Load a group of control words from an aligned address.
+  // Notice that there is no difference of aligned/unaligned
+  // loading in ASIMD.
+  static Group aligned_load(const void *__restrict addr) {
+    return {vld1_s8(static_cast(addr))};
+  }
+
+  // Store a group of control words to an aligned address.
+  void aligned_store(void *addr) const {
+    vst1_s8(static_cast(addr), data);
+  }
+
+  // Find out the lanes equal to the given byte and return the bitmask
+  // with corresponding bits set.
+  IteratableBitMask match_byte(uint8_t byte) const {
+    auto duplicated = vdup_n_s8(byte);
+    auto cmp = vceq_s8(data, duplicated);
+    auto converted = vget_lane_u64(vreinterpret_u64_u8(cmp), 0);
+    return {converted & BitMask::MASK};
+  }
+
+  // Find out the lanes equal to EMPTY and return the bitmask
+  // with corresponding bits set. All 8 bits of a matching lane stay set,
+  // so this mask is not iterable; callers only test or count its bits.
+  BitMask mask_empty() const {
+    auto duplicated = vdup_n_s8(EMPTY);
+    auto cmp = vceq_s8(data, duplicated);
+    auto converted = vget_lane_u64(vreinterpret_u64_u8(cmp), 0);
+    return {converted};
+  }
+
+  // Find out the lanes equal to EMPTY or DELETED (highest bit set) and
+  // return the bitmask with corresponding bits set.
+  BitMask mask_empty_or_deleted() const {
+    auto converted = vget_lane_u64(vreinterpret_u64_s8(data), 0);
+    return {converted & BitMask::MASK};
+  }
+
+  // Find out the lanes representing full cells (without highest bit) and
+  // return the bitmask with corresponding bits set.
+  BitMask mask_full() const { return mask_empty_or_deleted().invert(); }
+
+  // Performs the following transformation on all bytes in the group:
+  // - `EMPTY => EMPTY`
+  // - `DELETED => EMPTY`
+  // - `FULL => DELETED`
+  Group convert_special_to_empty_and_full_to_deleted() const {
+    auto dup = vdup_n_s8(0x80);
+    auto zero = vdup_n_s8(0x00);
+    // Special bytes are negative as int8_t, i.e. below zero.
+    auto special = vcgt_s8(zero, data);
+    // Group is an aggregate, so the result must be brace-initialized.
+    return {vorr_s8(dup, vreinterpret_s8_u8(special))};
+  }
+};
+
+} // namespace __llvm_libc::cpp::swisstable
+
+#endif // LLVM_LIBC_SRC_SUPPORT_SWISSTABLE_ASIMD_H
diff --git a/libc/src/__support/swisstable/common.h b/libc/src/__support/swisstable/common.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/swisstable/common.h
@@ -0,0 +1,89 @@
+
+//===-- SwissTable Common Definitions ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_SWISSTABLE_COMMON_H
+#define LLVM_LIBC_SRC_SUPPORT_SWISSTABLE_COMMON_H
+
+// Declares the safe_/unsafe_ clz/ctz helpers used by BitMaskAdaptor.
+#include "src/__support/builtin_wrappers.h"
+#include 
+#include 
+#include 
+#include 
+
+namespace __llvm_libc::cpp::swisstable {
+// Special values of the control byte. Both have the top bit set; the
+// control byte of a FULL bucket holds a 7-bit hash tag instead.
+using CtrlWord = uint8_t;
+using HashWord = uint64_t;
+constexpr static inline CtrlWord EMPTY = 0b11111111u;
+constexpr static inline CtrlWord DELETED = 0b10000000u;
+
+// Implementations of the bitmask.
+// The backend word type may vary depending on different microarchitectures.
+// For example, with X86 SSE2, the bitmask is just the 16bit unsigned integer
+// corresponding to lanes in a SIMD register. Each lane of a group occupies
+// WORD_STRIDE bits of the word.
+template struct BitMaskAdaptor {
+  // A masked constant whose bits are all set.
+  constexpr static inline T MASK = WORD_MASK;
+  // A stride in the bitmask may use multiple bits.
+  constexpr static inline size_t STRIDE = WORD_STRIDE;
+
+  T word;
+
+  // Invert zeros and ones inside the word.
+  BitMaskAdaptor invert() const { return {static_cast(word ^ WORD_MASK)}; }
+
+  // Operator helper to do bit manipulations.
+  BitMaskAdaptor operator^(T value) const {
+    return {static_cast(this->word ^ value)};
+  }
+
+  // Check if any bit is set inside the word.
+  bool any_bit_set() const { return word != 0; }
+
+  // Count trailing zero lanes (lanes below the lowest set bit). Uses the
+  // safe_ variant so an all-zero word is accepted.
+  size_t trailing_zeros() const { return safe_ctz(word) / WORD_STRIDE; }
+
+  // Index of the lane holding the lowest set bit. The word must be nonzero.
+  size_t lowest_set_bit_nonzero() const {
+    return unsafe_ctz(word) / WORD_STRIDE;
+  }
+
+  // Count leading zero lanes (lanes above the highest set bit), counted
+  // from the top of the word. Uses the safe_ variant so an all-zero word is
+  // accepted.
+  size_t leading_zeros() const {
+    return safe_clz(word) / WORD_STRIDE;
+  }
+};
+
+template struct IteratableBitMaskAdaptor : public BitMask {
+  // Use the bitmask as an iterator. Update the state and return current lowest
+  // set bit. To make the bitmask iterable, each stride must contain 0 or
+  // exactly 1 set bit.
+  void remove_lowest_bit() {
+    // Remove the last set bit inside the word:
+    //    word              = 011110100 (original value)
+    //    word - 1          = 011110011 (invert all bits up to the last set bit)
+    //    word & (word - 1) = 011110000 (value with the last bit cleared)
+    this->word = this->word & (this->word - 1);
+  }
+  using value_type = size_t;
+  using iterator = BitMask;
+  using const_iterator = BitMask;
+  // Current element: the lane index of the lowest remaining set bit.
+  size_t operator*() { return this->lowest_set_bit_nonzero(); }
+  IteratableBitMaskAdaptor &operator++() {
+    this->remove_lowest_bit();
+    return *this;
+  }
+  IteratableBitMaskAdaptor begin() { return *this; }
+  // Iteration ends once every set bit has been removed.
+  IteratableBitMaskAdaptor end() { return {0}; }
+  bool operator==(const IteratableBitMaskAdaptor &other) {
+    return this->word == other.word;
+  }
+  bool operator!=(const IteratableBitMaskAdaptor &other) {
+    return this->word != other.word;
+  }
+};
+
+} // namespace __llvm_libc::cpp::swisstable
+
+#endif // LLVM_LIBC_SRC_SUPPORT_SWISSTABLE_COMMON_H
diff --git a/libc/src/__support/swisstable/dispatch.h b/libc/src/__support/swisstable/dispatch.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/swisstable/dispatch.h
@@ -0,0 +1,16 @@
+//===-- SwissTable Platform Dispatch ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "src/__support/architectures.h" +#if defined(LLVM_LIBC_ARCH_X86_64) && defined(__SSE2__) +#include "src/__support/swisstable/sse2.h" +#elif defined(LLVM_LIBC_ARCH_ANY_ARM) && defined(__ARM_NEON) && \ + defined(__ORDER_LITTLE_ENDIAN__) +#include "src/__support/swisstable/asimd.h" +#else +#include "src/__support/swisstable/generic.h" +#endif diff --git a/libc/src/__support/swisstable/generic.h b/libc/src/__support/swisstable/generic.h new file mode 100644 --- /dev/null +++ b/libc/src/__support/swisstable/generic.h @@ -0,0 +1,149 @@ +//===-- SwissTable Generic Fallback -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "src/__support/endian.h" +#include "src/__support/swisstable/common.h" +#include "src/string/memory_utils/memcpy_implementations.h" + +namespace __llvm_libc::cpp::swisstable { + +// Helper function to spread a byte across the whole word. +// Accumutively, the procedure looks like: +// byte = 0x00000000000000ff +// byte | (byte << 8) = 0x000000000000ffff +// byte | (byte << 16) = 0x00000000ffffffff +// byte | (byte << 32) = 0xffffffffffffffff +constexpr static inline uintptr_t repeat(uintptr_t byte) { + size_t shift_amount = 8; + while (shift_amount < sizeof(uintptr_t) * 8) { + byte |= byte << shift_amount; + shift_amount <<= 1; + } + return byte; +} + +using BitMask = BitMaskAdaptor; +using IteratableBitMask = IteratableBitMaskAdaptor; + +struct Group { + uintptr_t data; + + // Load a group of control words from an arbitary address. 
+ static Group load(const void *__restrict addr) { + uintptr_t data; + inline_memcpy(reinterpret_cast(&data), + static_cast(addr), sizeof(data)); + return {data}; + } + + // Load a group of control words from an aligned address. + // Notice that there is no difference of aligned/unaigned + // loading in ASIMD. + static Group aligned_load(const void *__restrict addr) { + uintptr_t data = *static_cast(addr); + return {data}; + } + + // Store a group of control words to an aligned address. + void aligned_store(void *addr) const { + *static_cast(addr) = data; + } + + // Find out the lanes equal to the given byte and return the bitmask + // with corresponding bits set. + IteratableBitMask match_byte(uint8_t byte) const { + // Given byte = 0x10, suppose the data is: + // + // data = [ 0x10 | 0x10 | 0x00 | 0xF1 | ... ] + // + // First, we compare the byte using XOR operation: + // + // [ 0x10 | 0x10 | 0x10 | 0x10 | ... ] (0) + // ^ [ 0x10 | 0x10 | 0x00 | 0xF1 | ... ] (1) + // = [ 0x00 | 0x00 | 0x10 | 0xE1 | ... ] (2) + // + // Notice that the equal positions will now be 0x00, so if we substract 0x01 + // respective to every byte, it will need to carry the substraction to upper + // bits (assume no carry from the hidden parts) + // [ 0x00 | 0x00 | 0x10 | 0xE1 | ... ] (2) + // - [ 0x01 | 0x01 | 0x01 | 0x01 | ... ] (3) + // = [ 0xFE | 0xFF | 0x0F | 0xE0 | ... ] (4) + // + // But there may be some bytes whose highest bit is already set after the + // xor operation. To rule out these positions, we AND them with the NOT + // of the XOR result: + // + // [ 0xFF | 0xFF | 0xEF | 0x1E | ... ] (5, NOT (2)) + // & [ 0xFE | 0xFF | 0x0F | 0xE0 | ... ] (4) + // = [ 0xFE | 0xFF | 0x0F | 0x10 | ... ] (6) + // + // To make the bitmask iteratable, only one bit can be set in each stride. + // So we AND each byte with 0x80 and keep only the highest bit: + // + // [ 0xFE | 0xFF | 0x0F | 0x10 | ... ] (6) + // & [ 0x80 | 0x80 | 0x80 | 0x80 | ... ] (7) + // = [ 0x80 | 0x80 | 0x00 | 0x00 | ... 
] (8) + // + // However, there are possitbilites for false positives. For example, if the + // data is [ 0x10 | 0x11 | 0x10 | 0xF1 | ... ]. This only happens when there + // is a key only differs from the searched by the lowest bit. The claims + // are: + // + // - This never happens for `EMPTY` and `DELETED`, only full entries. + // - The check for key equality will catch these. + // - This only happens if there is at least 1 true match. + // - The chance of this happening is very low (< 1% chance per byte). + auto cmp = data ^ repeat(byte); + auto result = + Endian::to_little_endian((cmp - repeat(0x01)) & ~cmp & repeat(0x80)); + return {result}; + } + + // Find out the lanes equal to EMPTY and return the bitmask + // with corresponding bits set. + BitMask mask_empty() const { + // If the high bit is set, then the byte must be either: + // 1111_1111 (EMPTY) or 1000_0000 (DELETED). + // So we can just check if the top two bits are 1 by ANDing them. + return {Endian::to_little_endian(data & (data << 1) & repeat(0x80))}; + } + + // Find out the lanes equal to EMPTY or DELETE (highest bit set) and + // return the bitmask with corresponding bits set. + BitMask mask_empty_or_deleted() const { + return {Endian::to_little_endian(data) & repeat(0x80)}; + } + + // Find out the lanes reprensting full cells (without highest bit) and + // return the bitmask with corresponding bits set. + BitMask mask_full() const { return mask_empty_or_deleted().invert(); } + + // Performs the following transformation on all bytes in the group: + // - `EMPTY => EMPTY` + // - `DELETED => EMPTY` + // - `FULL => DELETED` + Group convert_special_to_empty_and_full_to_deleted() const { + // Set the highest bit only for positions whose highest bit is not set + // before. + // + // data = [ 00000000 | 11111111 | 10000000 | ... ] + // ~data = [ 11111111 | 00000000 | 00000000 | ... ] + // full = [ 10000000 | 00000000 | 00000000 | ... 
] + + auto full = (~data) & repeat(0x80); + + // Invert the bits and convert `01111111` to `10000000` by + // adding `1` to that byte. The carry will not propagate outside + // that byte: + // ~full = [ 01111111 | 11111111 | 11111111 | ... ] + // full >> 7 = [ 00000001 | 00000000 | 00000000 | ... ] + // result = [ 10000000 | 11111111 | 11111111 | ... ] + return {~full + (full >> 7)}; + } +}; + +} // namespace __llvm_libc::cpp::swisstable diff --git a/libc/src/__support/swisstable/sse2.h b/libc/src/__support/swisstable/sse2.h new file mode 100644 --- /dev/null +++ b/libc/src/__support/swisstable/sse2.h @@ -0,0 +1,78 @@ +//===-- SwissTable SSE2 Specialization --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/swisstable/common.h" +#include + +namespace __llvm_libc::cpp::swisstable { + +// With SSE2, every bitmask is iteratable: because +// we use single bit to encode the data. + +using BitMask = BitMaskAdaptor; +using IteratableBitMask = IteratableBitMaskAdaptor; + +struct Group { + __m128i data; + + // Load a group of control words from an arbitary address. + static Group load(const void *__restrict addr) { + return {_mm_loadu_si128(static_cast(addr))}; + } + + // Load a group of control words from an aligned address. + static Group aligned_load(const void *__restrict addr) { + return {_mm_load_si128(static_cast(addr))}; + } + + // Store a group of control words to an aligned address. 
+ void aligned_store(void *addr) const { + _mm_store_si128(static_cast<__m128i *>(addr), data); + } + + // Find out the lanes equal to the given byte and return the bitmask + // with corresponding bits set.
+ IteratableBitMask match_byte(uint8_t byte) const { + auto cmp = _mm_cmpeq_epi8(data, _mm_set1_epi8(byte)); + auto bitmask = static_cast(_mm_movemask_epi8(cmp)); + return {bitmask}; + } + + // Find out the lanes equal to EMPTY and return the bitmask + // with corresponding bits set. + BitMask mask_empty() const { return match_byte(EMPTY); } + + // Find out the lanes equal to EMPTY or DELETED (highest bit set) and + // return the bitmask with corresponding bits set. + BitMask mask_empty_or_deleted() const { + auto bitmask = static_cast(_mm_movemask_epi8(data)); + return {bitmask}; + } + + // Find out the lanes representing full cells (without highest bit) and + // return the bitmask with corresponding bits set. + BitMask mask_full() const { return mask_empty_or_deleted().invert(); } + + // Performs the following transformation on all bytes in the group: + // - `EMPTY => EMPTY` + // - `DELETED => EMPTY` + // - `FULL => DELETED` + Group convert_special_to_empty_and_full_to_deleted() const { + // Recall that EMPTY and DELETED are distinguished from others in + // their highest bit. This makes them negative when considered as + // signed integer. And for full ones, highest bits are all zeros. + // So we first turn lanes strictly smaller than zero into 0xFF (EMPTY), + // then set the highest bit of every lane, turning full lanes into 0x80. + __m128i zero = _mm_setzero_si128(); + __m128i special = _mm_cmpgt_epi8(zero, data); + __m128i converted = _mm_or_si128(special, _mm_set1_epi8(0x80u)); + return {converted}; + } +}; + +} // namespace __llvm_libc::cpp::swisstable diff --git a/libc/src/__support/wyhash.h b/libc/src/__support/wyhash.h new file mode 100644 --- /dev/null +++ b/libc/src/__support/wyhash.h @@ -0,0 +1,107 @@ +//===-- A 64-bit Hash Function ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/UInt128.h" +#include "src/__support/common.h" +#include "src/__support/endian.h" +#include "src/string/memory_utils/memcpy_implementations.h" + +// WyHash comes from https://github.com/wangyi-fudan/wyhash/ (which has been +// released into the public domain). It is also the default hash function +// of Go, Nim and Zig. +// +// According to SMHasher's results, it is one of the fastest hash functions +// without primary quality issues. This hash function can be used with +// swisstable as it has the Avalanche effect to encode the data into full +// 64-bits; thus the table can effectively utilize two level hash technique +// to effectively locate candidate entries without too many false positives. +namespace __llvm_libc::cpp::wyhash { +static constexpr inline uint64_t DEFAULT_SECRET[4] = { + 0xa0761d6478bd642full, 0xe7037ed1a0b428dbull, 0x8ebc6af09c88c6e3ull, + 0x589965cc75374cc3ull}; +template class WyHash { + static void multiply(uint64_t &a, uint64_t &b) { + UInt128 data{a}; + data *= b; + if constexpr (EntropyProtection) { + a ^= static_cast(data); + b ^= static_cast(data >> 64); + } else { + a = static_cast(data); + b = static_cast(data >> 64); + } + } + static uint64_t mix(uint64_t a, uint64_t b) { + multiply(a, b); + return a ^ b; + } + template static uint64_t read(const uint8_t *p) { + uint64_t value = 0; // N may be < 8: bytes not copied must read as zero. + inline_memcpy(reinterpret_cast(&value), + reinterpret_cast(p), N); + return Endian::to_little_endian(value); + } + + static uint64_t read3(const uint8_t *p, size_t k) { + auto a = static_cast(p[0]) << 16; + auto b = static_cast(p[k / 2]) << 8; + auto c = static_cast(p[k - 1]); + return a | b | c; + } + + static uint64_t wyhash(const void *key, size_t length, uint64_t seed, + const uint64_t *secret) { + const uint8_t *p = static_cast(key); + seed ^= secret[0]; + uint64_t a = 0, b = 0; + if
(likely(length <= 16)) { + if (likely(length >= 4)) { + a = (read<4>(p) << 32) | read<4>(p + ((length >> 3) << 2)); + b = (read<4>(p + length - 4) << 32) | + read<4>(p + length - 4 - ((length >> 3) << 2)); + } else if (likely(length > 0)) { + a = read3(p, length); + b = 0; + } + } else { + size_t i = length; + if (likely(i > 48)) { + uint64_t s1 = seed, s2 = seed; + do { + seed = mix(read<8>(p) ^ secret[1], read<8>(p + 8) ^ seed); + s1 = mix(read<8>(p + 16) ^ secret[2], read<8>(p + 24) ^ s1); + s2 = mix(read<8>(p + 32) ^ secret[3], read<8>(p + 40) ^ s2); + p += 48; + i -= 48; + } while (likely(i > 48)); + seed ^= s1 ^ s2; + } + while (unlikely(i > 16)) { + seed = mix(read<8>(p) ^ secret[1], read<8>(p + 8) ^ seed); + i -= 16; + p += 16; + } + a = read<8>(p + i - 16); + b = read<8>(p + i - 8); + } + // Golang's implementation mixes a different secret here, but we keep it the + // same as the original wyhash to make it easier to examine the correctness + // via testing known values. + return mix(secret[1] ^ length, mix(a ^ secret[1], b ^ seed)); + } + +public: + static uint64_t hash(const void *key, size_t length, uint64_t seed) { + return wyhash(key, length, seed, DEFAULT_SECRET); + } +}; + +// Follow the practice in Golang to disable low-entropy protection by default. 
+using DefaultHash = WyHash; + +} // namespace __llvm_libc::cpp::wyhash diff --git a/libc/src/search/CMakeLists.txt b/libc/src/search/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/libc/src/search/CMakeLists.txt @@ -0,0 +1,68 @@ +add_subdirectory(hashtable) + +add_entrypoint_object( + hsearch + SRCS + hsearch.cpp + HDRS + hsearch.h + DEPENDS + .hashtable.global + .hashtable.search_impl +) + +add_entrypoint_object( + hsearch_r + SRCS + hsearch_r.cpp + HDRS + hsearch_r.h + DEPENDS + .hashtable.search_impl +) + +add_entrypoint_object( + hcreate + SRCS + hcreate.cpp + HDRS + hcreate.h + DEPENDS + .hashtable.global + .hashtable.utils + libc.include.errno +) + +add_entrypoint_object( + hcreate_r + SRCS + hcreate_r.cpp + HDRS + hcreate_r.h + DEPENDS + .hashtable.utils + libc.include.stdlib + libc.include.errno +) + +add_entrypoint_object( + hdestroy + SRCS + hdestroy.cpp + HDRS + hdestroy.h + DEPENDS + .hashtable.global + .hashtable.utils +) + +add_entrypoint_object( + hdestroy_r + SRCS + hdestroy_r.cpp + HDRS + hdestroy_r.h + DEPENDS + .hashtable.utils + libc.include.stdlib +) diff --git a/libc/src/search/hashtable/CMakeLists.txt b/libc/src/search/hashtable/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/libc/src/search/hashtable/CMakeLists.txt @@ -0,0 +1,30 @@ +add_header_library( + utils + HDRS + utils.h + DEPENDS + libc.src.__support.wyhash + libc.src.__support.swisstable + libc.src.string.strcmp + libc.src.string.strlen + libc.include.search +) + +add_object_library( + global + SRCS + global.cpp + HDRS + global.h + DEPENDS + .utils +) + +add_header_library( + search_impl + HDRS + search_impl.h + DEPENDS + .utils + libc.include.errno +) diff --git a/libc/src/search/hashtable/global.h b/libc/src/search/hashtable/global.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hashtable/global.h @@ -0,0 +1,18 @@ +//===-- Implementation Header of Global Hashtable -------------------------===// +// +// Part of the LLVM Project, under the Apache License 
v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HASHTABLE_GLOBAL_H +#define LLVM_LIBC_SRC_SEARCH_HASHTABLE_GLOBAL_H + +#include "src/search/hashtable/utils.h" + +namespace __llvm_libc::search::hashtable { +extern SeededTable global_raw_table; +} // namespace __llvm_libc::search::hashtable + +#endif // LLVM_LIBC_SRC_SEARCH_HASHTABLE_GLOBAL_H diff --git a/libc/src/search/hashtable/global.cpp b/libc/src/search/hashtable/global.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hashtable/global.cpp @@ -0,0 +1,13 @@ +//===-- Implementation of Global Hashtable --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hashtable/global.h" + +namespace __llvm_libc::search::hashtable { +SeededTable global_raw_table; +} diff --git a/libc/src/search/hashtable/search_impl.h b/libc/src/search/hashtable/search_impl.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hashtable/search_impl.h @@ -0,0 +1,45 @@ +//===-- Implementation Header of Hashtable Search ------------------- ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef LLVM_LIBC_SRC_SEARCH_HASHTABLE_SEARCH_IMPL_H +#define LLVM_LIBC_SRC_SEARCH_HASHTABLE_SEARCH_IMPL_H + +#include "src/search/hashtable/utils.h" +#include +namespace __llvm_libc::search::hashtable { + +static inline int search_impl(ENTRY item, ACTION action, ENTRY **retval, + SeededTable *table) { + using namespace search; + Hash hasher{table->seed}; + switch (action) { + case ENTER: { + auto bucket = table->raw.find_or_insert(item, hasher, equal); + if (!bucket) { + *retval = nullptr; + errno = ENOMEM; + return 0; + } + *retval = bucket.ptr - 1; + return 1; + } + case FIND: { + auto bucket = table->raw.find(item, hasher, equal); + if (!bucket) { + *retval = nullptr; + errno = ESRCH; + return 0; + } + *retval = bucket.ptr - 1; + return 1; + } + } + *retval = nullptr; // Unrecognized ACTION: never leave *retval unset. + return 0; +} +} // namespace __llvm_libc::search::hashtable +#endif // LLVM_LIBC_SRC_SEARCH_HASHTABLE_SEARCH_IMPL_H diff --git a/libc/src/search/hashtable/utils.h b/libc/src/search/hashtable/utils.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hashtable/utils.h @@ -0,0 +1,68 @@ +//===-- Utilities of hashtable --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HASHTABLE_UTILS_H +#define LLVM_LIBC_SRC_SEARCH_HASHTABLE_UTILS_H + +#ifndef LLVM_LIBC_SEARCH_ENABLE_HASHTABLE_RESIZE +#define LLVM_LIBC_SEARCH_ENABLE_HASHTABLE_RESIZE 0 +#endif + +#ifndef LLVM_LIBC_SEARCH_ENABLE_HASHTABLE_DELETION +#define LLVM_LIBC_SEARCH_ENABLE_HASHTABLE_DELETION 0 +#endif + +#include "src/__support/swisstable.h" +#include "src/__support/wyhash.h" +#include "src/string/strcmp.h" +#include "src/string/strlen.h" +#include + +namespace __llvm_libc::search::hashtable { + +static inline bool equal(const ENTRY &a, const ENTRY &b) { + return strcmp(a.key, b.key) == 0; +} + +using RawTable = + cpp::swisstable::RawTable; + +struct SeededTable { + RawTable raw; + uint64_t seed; +}; + +using Bucket = cpp::swisstable::Bucket; + +struct TableHeader { + SeededTable *table; +}; + +static inline SeededTable *get_table(hsearch_data *hdata) { + return reinterpret_cast(hdata)->table; +} + +static inline void set_table(hsearch_data *hdata, SeededTable *table) { + reinterpret_cast(hdata)->table = table; +} + +struct Hash { + uint64_t seed; + uint64_t operator()(const ENTRY &entry) const { + size_t length = __llvm_libc::strlen(entry.key); + return cpp::wyhash::DefaultHash::hash(entry.key, length, seed); + } +}; + +// For now, we use only for seed to initialize the table. +// The number comes from Kunth's PRNG. +static constexpr inline uint64_t DEFAULT_SEED = 6364136223846793005; + +} // namespace __llvm_libc::search::hashtable +#endif // LLVM_LIBC_SRC_SEARCH_HASHTABLE_UTILS_H diff --git a/libc/src/search/hcreate.h b/libc/src/search/hcreate.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hcreate.h @@ -0,0 +1,18 @@ +//===-- Implementation header for hcreate -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HCREATE_H +#define LLVM_LIBC_SRC_SEARCH_HCREATE_H + +#include + +namespace __llvm_libc { +int hcreate(size_t); +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SEARCH_HCREATE_H diff --git a/libc/src/search/hcreate.cpp b/libc/src/search/hcreate.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hcreate.cpp @@ -0,0 +1,26 @@ +//===-- Implementation of hcreate -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hcreate.h" +#include "src/search/hashtable/global.h" +#include "src/search/hashtable/utils.h" +#include + +namespace __llvm_libc { +LLVM_LIBC_FUNCTION(int, hcreate, (size_t nel)) { + using namespace search::hashtable; + global_raw_table.raw = RawTable::with_capacity(nel); + global_raw_table.seed = DEFAULT_SEED ^ reinterpret_cast(&nel); + if (!global_raw_table.raw.is_valid()) { + errno = ENOMEM; + return 0; + } + return 1; +} + +} // namespace __llvm_libc diff --git a/libc/src/search/hcreate_r.h b/libc/src/search/hcreate_r.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hcreate_r.h @@ -0,0 +1,18 @@ +//===-- Implementation header for hcreate_r ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HCREATE_R_H +#define LLVM_LIBC_SRC_SEARCH_HCREATE_R_H + +#include // ENTRY, ACTION, hsearch_data + +namespace __llvm_libc { +int hcreate_r(size_t, hsearch_data *); +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SEARCH_HCREATE_R_H diff --git a/libc/src/search/hcreate_r.cpp b/libc/src/search/hcreate_r.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hcreate_r.cpp @@ -0,0 +1,45 @@ +//===-- Implementation of hcreate_r -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hcreate_r.h" +#include "src/search/hashtable/utils.h" +#include +#include +#include + +namespace __llvm_libc { +LLVM_LIBC_FUNCTION(int, hcreate_r, (size_t nel, hsearch_data *hdata)) { + using namespace search::hashtable; + + if (hdata == nullptr) { + errno = EINVAL; + return 0; + } + + SeededTable *table = static_cast( + aligned_alloc(alignof(SeededTable), sizeof(SeededTable))); + + if (table == nullptr) { + errno = ENOMEM; + return 0; + } + + table->raw = RawTable::with_capacity(nel); + + if (!table->raw.is_valid()) { + free(table); + errno = ENOMEM; + return 0; + } + + set_table(hdata, table); + table->seed = DEFAULT_SEED ^ reinterpret_cast(table); + return 1; +} + +} // namespace __llvm_libc diff --git a/libc/src/search/hdestroy.h b/libc/src/search/hdestroy.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hdestroy.h @@ -0,0 +1,16 @@ +//===-- Implementation header for hdestroy ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HDESTROY_H +#define LLVM_LIBC_SRC_SEARCH_HDESTROY_H + +namespace __llvm_libc { +void hdestroy(); +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SEARCH_HDESTROY_H diff --git a/libc/src/search/hdestroy.cpp b/libc/src/search/hdestroy.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hdestroy.cpp @@ -0,0 +1,18 @@ +//===-- Implementation of hdestroy ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hdestroy.h" +#include "src/search/hashtable/global.h" +namespace __llvm_libc { +LLVM_LIBC_FUNCTION(void, hdestroy, ()) { + using namespace search::hashtable; + global_raw_table.raw.release(); + global_raw_table.seed = 0; +} + +} // namespace __llvm_libc diff --git a/libc/src/search/hdestroy_r.h b/libc/src/search/hdestroy_r.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hdestroy_r.h @@ -0,0 +1,18 @@ +//===-- Implementation header for hdestroy_r --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HDESTROY_R_H +#define LLVM_LIBC_SRC_SEARCH_HDESTROY_R_H + +#include + +namespace __llvm_libc { +void hdestroy_r(hsearch_data *); +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SEARCH_HDESTROY_R_H diff --git a/libc/src/search/hdestroy_r.cpp b/libc/src/search/hdestroy_r.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hdestroy_r.cpp @@ -0,0 +1,28 @@ +//===-- Implementation of hdestroy_r ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hdestroy_r.h" +#include "src/search/hashtable/utils.h" +#include +namespace __llvm_libc { +LLVM_LIBC_FUNCTION(void, hdestroy_r, (hsearch_data * hdata)) { + using namespace search::hashtable; + // Nothing to do for a null descriptor or one without a table. + if (hdata == nullptr) + return; + SeededTable *table = get_table(hdata); + if (table == nullptr) + return; + table->raw.release(); + table->seed = 0; + free(table); + // Clear the stale pointer so a second hdestroy_r is a no-op. + set_table(hdata, nullptr); +} + +} // namespace __llvm_libc diff --git a/libc/src/search/hsearch.h b/libc/src/search/hsearch.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hsearch.h @@ -0,0 +1,18 @@ +//===-- Implementation header for hsearch -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HSEARCH_H +#define LLVM_LIBC_SRC_SEARCH_HSEARCH_H + +#include // ENTRY, ACTION, hsearch_data + +namespace __llvm_libc { +ENTRY *hsearch(ENTRY item, ACTION action); +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SEARCH_HSEARCH_H diff --git a/libc/src/search/hsearch.cpp b/libc/src/search/hsearch.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hsearch.cpp @@ -0,0 +1,21 @@ +//===-- Implementation of hsearch -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hsearch.h" +#include "src/search/hashtable/global.h" +#include "src/search/hashtable/search_impl.h" + +namespace __llvm_libc { +LLVM_LIBC_FUNCTION(ENTRY *, hsearch, (ENTRY item, ACTION action)) { + using namespace search::hashtable; + ENTRY *e = nullptr; // Stays null if search_impl fails without setting it. + search_impl(item, action, &e, &global_raw_table); + return e; +} + +} // namespace __llvm_libc diff --git a/libc/src/search/hsearch_r.h b/libc/src/search/hsearch_r.h new file mode 100644 --- /dev/null +++ b/libc/src/search/hsearch_r.h @@ -0,0 +1,18 @@ +//===-- Implementation header for hsearch_r ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SEARCH_HSEARCH_R_H +#define LLVM_LIBC_SRC_SEARCH_HSEARCH_R_H + +#include // ENTRY, ACTION, hsearch_data + +namespace __llvm_libc { +int hsearch_r(ENTRY item, ACTION action, ENTRY **retval, hsearch_data *); +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SEARCH_HSEARCH_R_H diff --git a/libc/src/search/hsearch_r.cpp b/libc/src/search/hsearch_r.cpp new file mode 100644 --- /dev/null +++ b/libc/src/search/hsearch_r.cpp @@ -0,0 +1,20 @@ +//===-- Implementation of hsearch_r -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/search/hsearch_r.h" +#include "src/search/hashtable/search_impl.h" + +namespace __llvm_libc { +LLVM_LIBC_FUNCTION(int, hsearch_r, + (ENTRY item, ACTION action, ENTRY **retval, + hsearch_data *hdata)) { + using namespace search::hashtable; + return search_impl(item, action, retval, get_table(hdata)); +} + +} // namespace __llvm_libc diff --git a/libc/test/src/CMakeLists.txt b/libc/test/src/CMakeLists.txt --- a/libc/test/src/CMakeLists.txt +++ b/libc/test/src/CMakeLists.txt @@ -32,6 +32,7 @@ add_subdirectory(fenv) add_subdirectory(inttypes) add_subdirectory(math) +add_subdirectory(search) add_subdirectory(string) add_subdirectory(stdlib) add_subdirectory(stdio) diff --git a/libc/test/src/__support/CMakeLists.txt b/libc/test/src/__support/CMakeLists.txt --- a/libc/test/src/__support/CMakeLists.txt +++ b/libc/test/src/__support/CMakeLists.txt @@ -20,6 +20,17 @@ libc.src.__support.common ) +add_libc_unittest( + wyhash_test + SUITE + libc_support_unittests + SRCS + wyhash_test.cpp + 
DEPENDS + libc.src.__support.wyhash + libc.src.string.strlen +) + add_libc_unittest( high_precision_decimal_test SUITE @@ -111,3 +122,4 @@ add_subdirectory(CPP) add_subdirectory(File) add_subdirectory(OSUtil) +add_subdirectory(swisstable) diff --git a/libc/test/src/__support/swisstable/CMakeLists.txt b/libc/test/src/__support/swisstable/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/swisstable/CMakeLists.txt @@ -0,0 +1,69 @@ +add_libc_unittest( + safe_mem_size_test + SUITE + libc_support_unittests + SRCS + safe_mem_size_test.cpp + DEPENDS + libc.src.__support.swisstable + libc.src.__support.CPP.limits +) + +add_libc_unittest( + bitmask_test + SUITE + libc_support_unittests + SRCS + bitmask_test.cpp + DEPENDS + libc.src.__support.swisstable +) + +add_libc_unittest( + group_test + SUITE + libc_support_unittests + SRCS + group_test.cpp + COMPILE_OPTIONS + -DSWISSTABLE_TEST_USE_GENERIC_GROUP=0 + DEPENDS + libc.src.__support.swisstable + libc.src.string.memcmp +) + +add_libc_unittest( + group_generic_test + SUITE + libc_support_unittests + SRCS + group_test.cpp + COMPILE_OPTIONS + -DSWISSTABLE_TEST_USE_GENERIC_GROUP=1 + DEPENDS + libc.src.__support.swisstable + libc.src.string.memcmp +) + +add_header_library( + test_utils + HDRS + test_utils.h + DEPENDS + libc.include.stdlib + libc.src.string.memmove +) + +add_libc_unittest( + hashtable_test + SUITE + libc_support_unittests + SRCS + hashtable_test.cpp + DEPENDS + .test_utils + libc.src.__support.swisstable + libc.src.stdlib.rand + libc.src.stdlib.srand + libc.src.sys.random.getrandom +) diff --git a/libc/test/src/__support/swisstable/bitmask_test.cpp b/libc/test/src/__support/swisstable/bitmask_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/swisstable/bitmask_test.cpp @@ -0,0 +1,141 @@ +#include "src/__support/swisstable.h" +#include "utils/UnitTest/LibcTest.h" +#include +#include +namespace __llvm_libc::cpp::swisstable { + +using ShortBitMask = BitMaskAdaptor;
+using LargeBitMask = BitMaskAdaptor; + +TEST(LlvmLibcSwissTableBitMask, Invert) { + { + auto x = ShortBitMask{0b10101010'10101010}; + ASSERT_EQ(x.invert().word, static_cast(0b01010101'01010101)); + } + + { + auto x = LargeBitMask{0x80808080'00000000}; + ASSERT_EQ(x.invert().word, static_cast(0x00000000'80808080)); + } +} + +TEST(LlvmLibcSwissTableBitMask, SingleBitStrideLeadingZeros) { + ASSERT_EQ(ShortBitMask{0x8808}.leading_zeros(), size_t{0}); + ASSERT_EQ(ShortBitMask{0x0808}.leading_zeros(), size_t{4}); + ASSERT_EQ(ShortBitMask{0x0408}.leading_zeros(), size_t{5}); + ASSERT_EQ(ShortBitMask{0x0208}.leading_zeros(), size_t{6}); + ASSERT_EQ(ShortBitMask{0x0108}.leading_zeros(), size_t{7}); + ASSERT_EQ(ShortBitMask{0x0008}.leading_zeros(), size_t{12}); + + uint16_t data = 0xffff; + for (size_t i = 0; i < 16; ++i) { + ASSERT_EQ(ShortBitMask{data}.leading_zeros(), i); + data >>= 1; + } +} + +TEST(LlvmLibcSwissTableBitMask, MultiBitStrideLeadingZeros) { + ASSERT_EQ(LargeBitMask{0x80808080'80808080}.leading_zeros(), size_t{0}); + ASSERT_EQ(LargeBitMask{0x00808080'80808080}.leading_zeros(), size_t{1}); + ASSERT_EQ(LargeBitMask{0x00000080'80808080}.leading_zeros(), size_t{3}); + + ASSERT_EQ(LargeBitMask{0x80808080'80808080}.leading_zeros(), size_t{0}); + ASSERT_EQ(LargeBitMask{0x01808080'80808080}.leading_zeros(), size_t{0}); + ASSERT_EQ(LargeBitMask{0x10000080'80808080}.leading_zeros(), size_t{0}); + ASSERT_EQ(LargeBitMask{0x0F808080'80808080}.leading_zeros(), size_t{0}); + + uint64_t data = 0xffff'ffff'ffff'ffff; + for (size_t i = 0; i < 8; ++i) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(LargeBitMask{data}.leading_zeros(), i); + data >>= 1; + } + } +} + +TEST(LlvmLibcSwissTableBitMask, SingleBitStrideTrailingZeros) { + ASSERT_EQ(ShortBitMask{0x8808}.trailing_zeros(), size_t{3}); + ASSERT_EQ(ShortBitMask{0x0804}.trailing_zeros(), size_t{2}); + ASSERT_EQ(ShortBitMask{0x0802}.trailing_zeros(), size_t{1}); + ASSERT_EQ(ShortBitMask{0x0801}.trailing_zeros(),
size_t{0}); + ASSERT_EQ(ShortBitMask{0x0800}.trailing_zeros(), size_t{11}); + ASSERT_EQ(ShortBitMask{0x1000}.trailing_zeros(), size_t{12}); + + uint16_t data = 0xffff; + for (size_t i = 0; i < 16; ++i) { + ASSERT_EQ(ShortBitMask{data}.trailing_zeros(), i); + data <<= 1; + } +} + +TEST(LlvmLibcSwissTableBitMask, MultiBitStrideTrailingZeros) { + ASSERT_EQ(LargeBitMask{0x80808080'80808080}.trailing_zeros(), size_t{0}); + ASSERT_EQ(LargeBitMask{0x80808080'80808000}.trailing_zeros(), size_t{1}); + ASSERT_EQ(LargeBitMask{0x80808080'80804000}.trailing_zeros(), size_t{1}); + + ASSERT_EQ(LargeBitMask{0x80808080'80800000}.trailing_zeros(), size_t{2}); + ASSERT_EQ(LargeBitMask{0x80808080'80000000}.trailing_zeros(), size_t{3}); + ASSERT_EQ(LargeBitMask{0x80808080'00000000}.trailing_zeros(), size_t{4}); + ASSERT_EQ(LargeBitMask{0x80808000'00000000}.trailing_zeros(), size_t{5}); + + uint64_t data = 0xffff'ffff'ffff'ffff; + for (size_t i = 0; i < 8; ++i) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(LargeBitMask{data}.trailing_zeros(), i); + data <<= 1; + } + } +} + +TEST(LlvmLibcSwissTableBitMask, SingleBitStrideLowestSetBit) { + uint16_t data = 0xffff; + for (size_t i = 0; i < 16; ++i) { + if (ShortBitMask{data}.any_bit_set()) { + ASSERT_EQ(ShortBitMask{data}.lowest_set_bit_nonzero(), i); + data <<= 1; + } + } +} + +TEST(LlvmLibcSwissTableBitMask, MultiBitStrideLowestSetBit) { + uint64_t data = 0xffff'ffff'ffff'ffff; + for (size_t i = 0; i < 8; ++i) { + for (size_t j = 0; j < 8; ++j) { + if (LargeBitMask{data}.any_bit_set()) { + ASSERT_EQ(LargeBitMask{data}.lowest_set_bit_nonzero(), i); + data <<= 1; + } + } + } +} + +TEST(LlvmLibcSwissTableBitMask, SingleBitStrideIteration) { + using Iter = IteratableBitMaskAdaptor; + uint16_t data = 0xffff; + for (size_t i = 0; i < 16; ++i) { + Iter iter = {data}; + size_t j = i; + for (auto x : iter) { + ASSERT_EQ(x, j); + j++; + } + ASSERT_EQ(j, size_t{16}); + data <<= 1; + } +} + +TEST(LlvmLibcSwissTableBitMask, MultiBitStrideIteration)
{ + using Iter = IteratableBitMaskAdaptor; + uint64_t data = Iter::MASK; + for (size_t i = 0; i < 8; ++i) { + Iter iter = {data}; + size_t j = i; + for (auto x : iter) { + ASSERT_EQ(x, j); + j++; + } + ASSERT_EQ(j, size_t{8}); + data <<= Iter::STRIDE; + } +} +} // namespace __llvm_libc::cpp::swisstable diff --git a/libc/test/src/__support/swisstable/group_test.cpp b/libc/test/src/__support/swisstable/group_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/swisstable/group_test.cpp @@ -0,0 +1,120 @@ +#include +#if SWISSTABLE_TEST_USE_GENERIC_GROUP +#include "src/__support/swisstable/generic.h" +#define SWISSTABLE_TEST_SUITE(X) TEST(LlvmLibcSwissTableGroupGeneric, X) +#else +#include "src/__support/swisstable/dispatch.h" +#define SWISSTABLE_TEST_SUITE(X) TEST(LlvmLibcSwissTableGroup, X) +#endif +#include "src/string/memcmp.h" +#include "utils/UnitTest/LibcTest.h" +#include + +namespace __llvm_libc::cpp::swisstable { + +struct ByteArray { + alignas(Group) uint8_t data[sizeof(Group) + 1]{}; +}; + +SWISSTABLE_TEST_SUITE(LoadStore) { + ByteArray array{}; + ByteArray compare{}; + for (size_t i = 0; i < sizeof(array.data); ++i) { + array.data[i] = 0xff; + auto group = Group::aligned_load(array.data); + group.aligned_store(compare.data); + EXPECT_EQ(__llvm_libc::memcmp(compare.data, array.data, sizeof(Group)), 0); + group = Group::load(&array.data[1]); + group.aligned_store(compare.data); + EXPECT_EQ(__llvm_libc::memcmp(compare.data, &array.data[1], sizeof(Group)), + 0); + array.data[i] = 0; + } +} + +SWISSTABLE_TEST_SUITE(Match) { + // Any pair of targets have bit differences not only at the lowest bit. + // No False positive.
+ uint8_t targets[4] = {0x00, 0x11, 0xFF, 0x0F}; + size_t count[4] = {0, 0, 0, 0}; + size_t appearance[4][sizeof(Group)]; + ByteArray array{}; + + uintptr_t random = reinterpret_cast(&array) ^ + reinterpret_cast(aligned_alloc); + + for (size_t i = 0; i < sizeof(Group); ++i) { + size_t choice = random % 4; + random /= 4; + array.data[i] = targets[choice]; + appearance[choice][count[choice]++] = i; + } + + for (size_t t = 0; t < sizeof(targets); ++t) { + auto bitmask = Group::aligned_load(array.data).match_byte(targets[t]); + // Each match must be a recorded position, in order, with none missing. + size_t iterated = 0; + for (size_t position : bitmask) { + ASSERT_LT(iterated, count[t]); + ASSERT_EQ(appearance[t][iterated], position); + iterated++; + } + ASSERT_EQ(count[t], iterated); + } +} + +SWISSTABLE_TEST_SUITE(MaskEmpty) { + uint8_t values[4] = {0x00, 0x0F, DELETED, EMPTY}; + + for (size_t i = 0; i < sizeof(Group); ++i) { + ByteArray array{}; + + intptr_t random = reinterpret_cast(&array) ^ + reinterpret_cast(aligned_alloc); + ASSERT_FALSE(Group::aligned_load(array.data).mask_empty().any_bit_set()); + + array.data[i] = EMPTY; + for (size_t j = 0; j < sizeof(Group); ++j) { + if (i == j) + continue; + size_t sample_space = 3 + (j > i); + size_t choice = random % sample_space; + random /= sizeof(values); + array.data[j] = values[choice]; + } + + auto mask = Group::aligned_load(array.data).mask_empty(); + ASSERT_TRUE(mask.any_bit_set()); + ASSERT_EQ(mask.lowest_set_bit_nonzero(), i); + } +} + +SWISSTABLE_TEST_SUITE(MaskEmptyOrDeleted) { + uint8_t values[4] = {0x00, 0x0F, DELETED, EMPTY}; + for (size_t t = 2; t <= 3; ++t) { + for (size_t i = 0; i < sizeof(Group); ++i) { + ByteArray array{}; + + intptr_t random = reinterpret_cast(&array) ^ + reinterpret_cast(aligned_alloc); + ASSERT_FALSE(Group::aligned_load(array.data) + .mask_empty_or_deleted() + .any_bit_set()); + + array.data[i] = values[t]; + for (size_t j = 0; j < sizeof(Group); ++j) { + if (i == j) + continue; + size_t sample_space = 2 + 2 * (j > i); + size_t choice =
random % sample_space; + random /= sizeof(values); + array.data[j] = values[choice]; + } + + auto mask = Group::aligned_load(array.data).mask_empty_or_deleted(); + ASSERT_TRUE(mask.any_bit_set()); + ASSERT_EQ(mask.lowest_set_bit_nonzero(), i); + } + } +} +} // namespace __llvm_libc::cpp::swisstable diff --git a/libc/test/src/__support/swisstable/hashtable_test.cpp b/libc/test/src/__support/swisstable/hashtable_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/swisstable/hashtable_test.cpp @@ -0,0 +1,216 @@ +#include "llvm-libc-macros/stdlib-macros.h" +#include "src/__support/builtin_wrappers.h" +#include "src/__support/swisstable.h" +#include "src/stdlib/rand.h" +#include "src/stdlib/srand.h" +#include "src/sys/random/getrandom.h" +#include "test/src/__support/swisstable/test_utils.h" +#include "utils/UnitTest/LibcTest.h" + +namespace __llvm_libc::cpp::swisstable { + +enum TestAction { Insert = 0, Find = 1, Delete = 2 }; + +using FixedInsertOnlyTable = RawTable; +using ResizableInsertOnlyTable = RawTable; +using FixedTable = RawTable; +using ResizableTable = RawTable; + +static inline uint64_t sample() { + // Bits of entropy produced by one rand() call, i.e. the bit width of + // RAND_MAX. Counted explicitly so it does not depend on RAND_MAX's type. + size_t length = 0; + for (uint64_t rand_max = RAND_MAX; rand_max != 0; rand_max >>= 1) + ++length; + uint64_t accumulator = 0; + for (size_t i = 0; i < 64; i += length) { + accumulator |= static_cast(rand()) << i; + } + return accumulator; +} + +static inline uint64_t bad_sample(uint64_t x) { return x << 16; } + +static inline TestAction next_move(int max = 3) { + return static_cast(((rand() % max) + max) % max); +} + +static inline bool test_equal(const uint64_t &x, const uint64_t &y) { + return x == y; +} +static inline size_t test_hash(const uint64_t &x) { return x; } + +template +struct LlvmLibcSwissTableTest : testing::Test { + + void init_seed() { + unsigned int seed = 0; + __llvm_libc::getrandom(&seed, sizeof(seed), 0); + __llvm_libc::srand(seed); + } + + void bad_hash_example() { + auto table = Table::with_capacity(16384); + for (uint64_t i = 0; i < 16384; ++i) { + auto t = bad_sample(i); +
table.insert(t, test_hash); + } + for (uint64_t i = 0; i < 16384; ++i) { + ASSERT_EQ(*table.find(bad_sample(i), test_hash, test_equal), + bad_sample(i)); + } + for (uint64_t i = 16384; i < 32768; ++i) { + ASSERT_FALSE(static_cast(table.find(i, test_hash, test_equal))); + ASSERT_FALSE( + static_cast(table.find(bad_sample(i), test_hash, test_equal))); + } + table.release(); + } + + void random_walk(size_t attempts, size_t capacity) { + BrutalForceSet set; + auto table = Table::with_capacity(capacity); + for (size_t i = 0; i < attempts; ++i) { + switch (next_move(OpRange)) { + case Insert: { + auto x = sample(); + set.insert(x); + if (x & 1) { + if (!table.find(x, test_hash, test_equal)) { + ASSERT_EQ(*table.insert(x, test_hash), x); + } + } else { + ASSERT_EQ(*table.find_or_insert(x, test_hash, test_equal), x); + } + ASSERT_EQ(*table.find(x, test_hash, test_equal), x); + + break; + } + case Find: { + if (!set.size) + continue; + auto x = sample(); + ASSERT_EQ(static_cast(table.find(x, test_hash, test_equal)), + set.find(x)); + auto index = x % set.size; + ASSERT_EQ(*table.find(set.data[index], test_hash, test_equal), + set.data[index]); + break; + } + case Delete: { + if (!set.size) + continue; + auto x = sample(); + auto index = x % set.size; + auto target = set.data[index]; + set.erase(target); + ASSERT_EQ(*table.find(target, test_hash, test_equal), target); + table.erase(table.find(target, test_hash, test_equal)); + ASSERT_FALSE( + static_cast(table.find(target, test_hash, test_equal))); + break; + } + } + } + for (size_t i = 0; i < set.size; ++i) { + ASSERT_EQ(*table.find(set.data[i], test_hash, test_equal), set.data[i]); + } + for (size_t i = 0; i < attempts; ++i) { + auto x = sample(); + ASSERT_EQ(static_cast(table.find(x, test_hash, test_equal)), + set.find(x)); + } + table.release(); + } +}; + +using LlvmLibcSwissTableFixedInsertOnlyTableTest = + LlvmLibcSwissTableTest; +using LlvmLibcSwissTableResizableInsertOnlyTableTest = + LlvmLibcSwissTableTest; +using
LlvmLibcSwissTableFixedTableTest = LlvmLibcSwissTableTest; +using LlvmLibcSwissTableResizableTableTest = + LlvmLibcSwissTableTest; + +TEST_F(LlvmLibcSwissTableFixedInsertOnlyTableTest, RandomWalk) { + init_seed(); + for (size_t i = 0; i < 100; ++i) { + this->random_walk(10000, 10000); + } +} +TEST_F(LlvmLibcSwissTableResizableInsertOnlyTableTest, RandomWalk) { + init_seed(); + for (size_t i = 0; i < 100; ++i) { + this->random_walk(10000, 0); + } +} +TEST_F(LlvmLibcSwissTableFixedTableTest, RandomWalk) { + init_seed(); + for (size_t i = 0; i < 100; ++i) { + this->random_walk(10000, 10000); + } +} +TEST_F(LlvmLibcSwissTableResizableTableTest, RandomWalk) { + init_seed(); + for (size_t i = 0; i < 100; ++i) { + this->random_walk(10000, 0); + } +} + +TEST_F(LlvmLibcSwissTableFixedInsertOnlyTableTest, BadHash) { + bad_hash_example(); +} +TEST_F(LlvmLibcSwissTableResizableInsertOnlyTableTest, BadHash) { + bad_hash_example(); +} +TEST_F(LlvmLibcSwissTableFixedTableTest, BadHash) { bad_hash_example(); } +TEST_F(LlvmLibcSwissTableResizableTableTest, BadHash) { bad_hash_example(); } + +TEST(LlvmLibcSwissTableFixedInsertOnlyTableTest, Pressure) { + auto table = FixedInsertOnlyTable::with_capacity(512); + uint64_t i = 0; + for (; i < 512; ++i) { + if (!table.insert(i, test_hash)) { + break; + } + } + ASSERT_LE(i, static_cast(512)); + table.release(); +} + +TEST(LlvmLibcSwissTableFixedTableTest, InPlaceRehash) { + auto table = FixedTable::with_capacity(512); + uint64_t i = 0; + for (; i < 512; ++i) { + if (!table.insert(i, test_hash)) { + break; + } + } + + uint64_t n = i; + + ASSERT_GE(n, static_cast(128)); + + for (i = 0; i < 128; ++i) { + table.erase(table.find(i, test_hash, test_equal)); + } + + for (i = 1000; i < 1128; ++i) { + table.find_or_insert(i, test_hash, test_equal); + } + + for (i = 0; i < 128; ++i) { + ASSERT_FALSE(static_cast(table.find(i, test_hash, test_equal))); + } + + // Keys [128, n) were inserted and never erased, so they must survive. + for (i = 128; i < n; ++i) { + ASSERT_EQ(*table.find(i, test_hash, test_equal), i); + } +
+ for (i = 1000; i < 1128; ++i) { + ASSERT_EQ(*table.find(i, test_hash, test_equal), i); + } + + table.release(); +} +} // namespace __llvm_libc::cpp::swisstable diff --git a/libc/test/src/__support/swisstable/safe_mem_size_test.cpp b/libc/test/src/__support/swisstable/safe_mem_size_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/swisstable/safe_mem_size_test.cpp @@ -0,0 +1,60 @@ +#include "src/__support/CPP/limits.h" +#include "src/__support/swisstable.h" +#include "utils/UnitTest/LibcTest.h" + +namespace __llvm_libc::cpp::swisstable { + +static inline constexpr size_t SAFE_MEM_SIZE_TEST_LIMIT = + static_cast(__llvm_libc::cpp::numeric_limits::max()); + +TEST(LlvmLibcSwissTableSafeMemSize, Construction) { + ASSERT_FALSE(SafeMemSize{static_cast(-1)}.valid()); + ASSERT_FALSE(SafeMemSize{static_cast(-2)}.valid()); + ASSERT_FALSE(SafeMemSize{static_cast(-1024 + 33)}.valid()); + ASSERT_FALSE(SafeMemSize{static_cast(-1024 + 66)}.valid()); + ASSERT_FALSE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT + 1}.valid()); + ASSERT_FALSE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT + 13}.valid()); + + ASSERT_TRUE(SafeMemSize{static_cast(1)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(1024 + 13)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(2048 - 13)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(4096 + 1)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(8192 - 1)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(16384 + 15)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(32768 * 3)}.valid()); + ASSERT_TRUE(SafeMemSize{static_cast(65536 * 13)}.valid()); + ASSERT_TRUE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT}.valid()); + ASSERT_TRUE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT - 1}.valid()); + ASSERT_TRUE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT - 13}.valid()); +} + +TEST(LlvmLibcSwissTableSafeMemSize, Addition) { + auto max = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT}; + auto half = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 2}; + auto third = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 3}; +
+ ASSERT_TRUE(half.valid()); + ASSERT_TRUE(third.valid()); + ASSERT_TRUE((half + half).valid()); + ASSERT_TRUE((third + third + third).valid()); + ASSERT_TRUE((half + third).valid()); + + ASSERT_FALSE((max + SafeMemSize{static_cast(1)}).valid()); + ASSERT_FALSE((third + third + third + third).valid()); + ASSERT_FALSE((half + half + half).valid()); +} + +TEST(LlvmLibcSwissTableSafeMemSize, Multiplication) { + auto max = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT}; + auto half = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 2}; + auto third = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 3}; + + ASSERT_TRUE((max * SafeMemSize{static_cast(1)}).valid()); + ASSERT_TRUE((max * SafeMemSize{static_cast(0)}).valid()); + + ASSERT_FALSE((max * SafeMemSize{static_cast(2)}).valid()); + ASSERT_FALSE((half * half).valid()); + ASSERT_FALSE((half * SafeMemSize{static_cast(3)}).valid()); + ASSERT_FALSE((third * SafeMemSize{static_cast(4)}).valid()); +} +} // namespace __llvm_libc::cpp::swisstable diff --git a/libc/test/src/__support/swisstable/test_utils.h b/libc/test/src/__support/swisstable/test_utils.h new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/swisstable/test_utils.h @@ -0,0 +1,79 @@ +#ifndef LLVM_LIBC_TEST_SRC_SUPPORT_SWISSTABLE_TEST_UTILS_H +#define LLVM_LIBC_TEST_SRC_SUPPORT_SWISSTABLE_TEST_UTILS_H + +#include +#include +#include + +namespace __llvm_libc::cpp::swisstable { + +// Sorted-array set used as a reference oracle for the hash table tests. +struct BrutalForceSet { + size_t capacity; + size_t size; + uint64_t *data; + + BrutalForceSet() : capacity(16), size(0) { + data = static_cast(calloc(capacity, sizeof(uint64_t))); + } + + ~BrutalForceSet() { free(data); } + + // Owns `data`; copying would lead to a double free. + BrutalForceSet(const BrutalForceSet &) = delete; + BrutalForceSet &operator=(const BrutalForceSet &) = delete; + + // Returns a pointer to `target` if present, otherwise to the slot where it + // would be inserted (`data + size` if it is larger than every element). + uint64_t *binary_search(uint64_t target) { + size_t l = 0, r = size; + while (l < r) { + auto mid = (l + r) / 2; + if (data[mid] == target) { + return &data[mid]; + } + if (data[mid] < target) { + l = mid + 1; + } + if (data[mid] > target) { + r = mid; + } + } + return &data[l]; + } + + void grow() { + data = + static_cast(realloc(data, sizeof(uint64_t) * capacity * 2)); + capacity *= 2; + } + + void insert(uint64_t target) { + if (size == capacity) + grow(); + uint64_t *position = binary_search(target); + if (position == data + size || *position != target) { +
__llvm_libc::memmove(position + 1, position, + sizeof(uint64_t) * (data + size - position)); + size++; + *position = target; + } + } + + bool find(uint64_t target) { + uint64_t *position = binary_search(target); + return position != data + size && *position == target; + } + + void erase(uint64_t target) { + uint64_t *position = binary_search(target); + if (position != data + size && *position == target) { + __llvm_libc::memmove(position, position + 1, + sizeof(uint64_t) * (data + size - position - 1)); + size--; + } + } +}; +} // namespace __llvm_libc::cpp::swisstable + +#endif // LLVM_LIBC_TEST_SRC_SUPPORT_SWISSTABLE_TEST_UTILS_H diff --git a/libc/test/src/__support/wyhash_test.cpp b/libc/test/src/__support/wyhash_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/__support/wyhash_test.cpp @@ -0,0 +1,40 @@ +//===-- Unittests for wyhash ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/wyhash.h" +#include "src/string/strlen.h" +#include "utils/UnitTest/LibcTest.h" + +// Examination Data from SMHasher +TEST(LlvmLibcWyHashTest, DefaultValues) { + using namespace __llvm_libc::cpp::wyhash; + // clang-format off + const char *data[] = { + "", + "a", + "abc", + "message digest", + "abcdefghijklmnopqrstuvwxyz", + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789", + "12345678901234567890123456789012345678901234567890123456789012345678901234567890" + }; + uint64_t hash[] = { + 0x42bc986dc5eec4d3, + 0x84508dc903c31551, + 0xbc54887cfc9ecb1, + 0x6e2ff3298208a67c, + 0x9a64e42e897195b9, + 0x9199383239c32554, + 0x7c1ccf6bba30f5a5, + }; + // clang-format on + for (size_t i = 0; i < 7; ++i) { + ASSERT_EQ(DefaultHash::hash(data[i], __llvm_libc::strlen(data[i]), i), + hash[i]); + } +} diff --git a/libc/test/src/search/CMakeLists.txt b/libc/test/src/search/CMakeLists.txt new file mode 100644
--- /dev/null +++ b/libc/test/src/search/CMakeLists.txt @@ -0,0 +1,16 @@ +add_libc_testsuite(libc_search_unittests) +add_libc_unittest( + hsearch_test + SUITE + libc_search_unittests + SRCS + hsearch_test.cpp + DEPENDS + libc.src.search.hsearch_r + libc.src.search.hcreate_r + libc.src.search.hdestroy_r + libc.src.search.hsearch + libc.src.search.hcreate + libc.src.search.hdestroy + libc.src.errno.errno +) diff --git a/libc/test/src/search/hsearch_test.cpp b/libc/test/src/search/hsearch_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/search/hsearch_test.cpp @@ -0,0 +1,117 @@ +//===-- Unittests for hsearch ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm-libc-types/struct_hsearch_data.h" +#include "src/search/hcreate.h" +#include "src/search/hcreate_r.h" +#include "src/search/hdestroy.h" +#include "src/search/hdestroy_r.h" +#include "src/search/hsearch.h" +#include "test/ErrnoSetterMatcher.h" +#include "utils/UnitTest/LibcTest.h" +#include + +TEST(LlvmLibcHSearchTest, CreateTooLarge) { + using __llvm_libc::testing::ErrnoSetterMatcher::Fails; + struct hsearch_data hdata; + ASSERT_THAT(__llvm_libc::hcreate(-1), Fails(ENOMEM, 0)); + ASSERT_THAT(__llvm_libc::hcreate_r(-1, &hdata), Fails(ENOMEM, 0)); +} + +TEST(LlvmLibcHSearchTest, CreateInvalid) { + using __llvm_libc::testing::ErrnoSetterMatcher::Fails; + ASSERT_THAT(__llvm_libc::hcreate_r(16, nullptr), Fails(EINVAL, 0)); +} + +TEST(LlvmLibcHSearchTest, CreateValid) { + struct hsearch_data hdata; + ASSERT_GT(__llvm_libc::hcreate_r(16, &hdata), 0); + __llvm_libc::hdestroy_r(&hdata); + + ASSERT_GT(__llvm_libc::hcreate(16), 0); + __llvm_libc::hdestroy(); +} + +char search_data[] =
+ "1234567890abcdefghijklmnopqrstuvwxyz" + "1234567890abcdefghijklmnopqrstuvwxyz" + "1234567890abcdefghijklmnopqrstuvwxyz" + "1234567890abcdefghijklmnopqrstuvwxyz" + "1234567890abcdefghijklmnopqrstuvwxyz"; +char search_data2[] = + "@@@@@@@@@@@@@@!!!!!!!!!!!!!!!!!###########$$$$$$$$$$^^^^^^&&&&&&&&"; + +TEST(LlvmLibcHSearchTest, InsertTooMany) { + using __llvm_libc::testing::ErrnoSetterMatcher::Fails; + ASSERT_GT(__llvm_libc::hcreate(16), 0); + for (size_t i = 0; i < 16 * 7; ++i) { + ASSERT_EQ(__llvm_libc::hsearch({&search_data[i], nullptr}, ENTER)->key, + &search_data[i]); + } + ASSERT_THAT( + static_cast(__llvm_libc::hsearch({search_data2, nullptr}, ENTER)), + Fails(ENOMEM, static_cast(nullptr))); + __llvm_libc::hdestroy(); +} + +TEST(LlvmLibcHSearchTest, NotFound) { + using __llvm_libc::testing::ErrnoSetterMatcher::Fails; + ASSERT_GT(__llvm_libc::hcreate(16), 0); + ASSERT_THAT( + static_cast(__llvm_libc::hsearch({search_data2, nullptr}, FIND)), + Fails(ESRCH, static_cast(nullptr))); + for (size_t i = 0; i < 16 * 7; ++i) { + ASSERT_EQ(__llvm_libc::hsearch({&search_data[i], nullptr}, ENTER)->key, + &search_data[i]); + } + ASSERT_THAT( + static_cast(__llvm_libc::hsearch({search_data2, nullptr}, FIND)), + Fails(ESRCH, static_cast(nullptr))); + __llvm_libc::hdestroy(); +} + +TEST(LlvmLibcHSearchTest, Found) { + using __llvm_libc::testing::ErrnoSetterMatcher::Fails; + ASSERT_GT(__llvm_libc::hcreate(16), 0); + for (size_t i = 0; i < 16 * 7; ++i) { + ASSERT_EQ(__llvm_libc::hsearch( + {&search_data[i], reinterpret_cast(i)}, ENTER) + ->key, + &search_data[i]); + } + for (size_t i = 0; i < 16 * 7; ++i) { + ASSERT_EQ(__llvm_libc::hsearch({&search_data[i], nullptr}, FIND)->data, + reinterpret_cast(i)); + } + __llvm_libc::hdestroy(); +} + +TEST(LlvmLibcHSearchTest, OnlyInsertWhenNotFound) { + using __llvm_libc::testing::ErrnoSetterMatcher::Fails; + ASSERT_GT(__llvm_libc::hcreate(16), 0); + for (size_t i = 0; i < 16 * 5; ++i) { + ASSERT_EQ(__llvm_libc::hsearch( + {&search_data[i], reinterpret_cast(i)}, ENTER) + ->key,
+ &search_data[i]); + } + for (size_t i = 0; i < 16 * 7; ++i) { + ASSERT_EQ(__llvm_libc::hsearch( + {&search_data[i], reinterpret_cast(1000 + i)}, ENTER) + ->key, + &search_data[i]); + } + for (size_t i = 0; i < 16 * 5; ++i) { + ASSERT_EQ(__llvm_libc::hsearch({&search_data[i], nullptr}, FIND)->data, + reinterpret_cast(i)); + } + for (size_t i = 16 * 5; i < 16 * 7; ++i) { + ASSERT_EQ(__llvm_libc::hsearch({&search_data[i], nullptr}, FIND)->data, + reinterpret_cast(1000 + i)); + } + __llvm_libc::hdestroy(); +}