diff --git a/libc/benchmarks/CMakeLists.txt b/libc/benchmarks/CMakeLists.txt
--- a/libc/benchmarks/CMakeLists.txt
+++ b/libc/benchmarks/CMakeLists.txt
@@ -119,7 +119,7 @@
   EXCLUDE_FROM_ALL
   JSON.cpp
   JSON.h
-)
+  ../src/__support/swisstable/safe_mem_size.h)
 target_link_libraries(json PUBLIC libc-memory-benchmark)
 llvm_update_compile_flags(json)
 
diff --git a/libc/src/__support/CMakeLists.txt b/libc/src/__support/CMakeLists.txt
--- a/libc/src/__support/CMakeLists.txt
+++ b/libc/src/__support/CMakeLists.txt
@@ -191,6 +191,20 @@
   .uint128
 )
 
+add_header_library(
+  swisstable
+  HDRS
+    swisstable.h
+  DEPENDS
+    .common
+    .builtin_wrappers
+    .swisstable.dispatch
+    .swisstable.safe_mem_size
+    libc.src.string.memcpy
+    libc.src.string.memset
+    libc.include.stdlib
+)
+
 add_subdirectory(FPUtil)
 add_subdirectory(OSUtil)
 add_subdirectory(StringUtil)
diff --git a/libc/src/__support/swisstable.h b/libc/src/__support/swisstable.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/swisstable.h
@@ -0,0 +1,592 @@
+//===-- SwissTable ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SUPPORT_SWISSTABLE_H
+#define LLVM_LIBC_SUPPORT_SWISSTABLE_H
+
+#include "src/__support/CPP/array.h"
+#include "src/__support/builtin_wrappers.h"
+#include "src/__support/common.h"
+#include "src/__support/swisstable/dispatch.h"
+#include "src/__support/swisstable/safe_mem_size.h"
+#include "src/string/memcpy.h"
+#include "src/string/memory_utils/memcpy_implementations.h"
+#include "src/string/memset.h"
+
+#include <stdlib.h>
+
+namespace __llvm_libc::internal::swisstable {
+
+static inline bool is_full(CtrlWord ctrl) { return (ctrl & 0x80) == 0; }
+static inline bool special_is_empty(CtrlWord ctrl) {
+  return (ctrl & 0x01) != 0;
+}
+static inline size_t h1(HashWord hash) { return static_cast<size_t>(hash); }
+static inline CtrlWord h2(HashWord hash) {
+  static constexpr size_t HASH_LENGTH =
+      sizeof(size_t) < sizeof(HashWord) ? sizeof(size_t) : sizeof(HashWord);
+  // We want the top 7 bits of the hash.
+  return (hash >> (HASH_LENGTH * 8 - 7)) & 0x7f;
+}
+
+struct ProbeSequence {
+  size_t position;
+  size_t stride;
+
+  void move_next(size_t bucket_mask) {
+    stride += sizeof(Group);
+    position += stride;
+    position &= bucket_mask;
+  }
+};
+
+static inline size_t next_power_of_two(size_t val) {
+  size_t idx = __llvm_libc::unsafe_clz(val - 1);
+  return 1ull << ((8ull * sizeof(size_t)) - idx);
+}
+
+static inline size_t capacity_to_buckets(size_t cap) {
+  if (cap < 8) {
+    return (cap < 4) ? 4 : 8;
+  }
+  return next_power_of_two(cap * 8);
+}
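The control-byte encoding used above packs three states into one byte: a full slot carries the 7-bit `h2` tag of its element with the top bit clear, while EMPTY and DELETED both set the top bit and are told apart by bit 0. A self-contained sketch of that convention (the concrete EMPTY/DELETED values are assumptions; they live in swisstable/dispatch.h, which this diff does not touch):

```cpp
#include <cassert>
#include <cstdint>

constexpr uint8_t kEmpty = 0xff;   // assumed: all bits set
constexpr uint8_t kDeleted = 0x80; // assumed: only the top bit set

constexpr bool is_full(uint8_t ctrl) { return (ctrl & 0x80) == 0; }
constexpr bool special_is_empty(uint8_t ctrl) { return (ctrl & 0x01) != 0; }
constexpr uint8_t h2(uint64_t hash) { return (hash >> (64 - 7)) & 0x7f; }

int main() {
  // Full slots store a 7-bit tag, so their top bit is always clear.
  assert(is_full(h2(0x123456789abcdef0ull)));
  // EMPTY and DELETED both have the top bit set; bit 0 tells them apart.
  assert(!is_full(kEmpty) && special_is_empty(kEmpty));
  assert(!is_full(kDeleted) && !special_is_empty(kDeleted));
}
```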
+
+static inline size_t bucket_mask_to_capacity(size_t bucket_mask) {
+  if (bucket_mask < 8) {
+    return bucket_mask;
+  } else {
+    return (bucket_mask + 1) / 8 * 7;
+  }
+}
+
+template <typename T, size_t SIZE> struct ConstantStorage {
+  alignas(Group) T data[SIZE];
+  constexpr ConstantStorage(CtrlWord val) {
+    for (size_t i = 0; i < SIZE; ++i) {
+      data[i] = val;
+    }
+  }
+};
+
+constexpr static inline ConstantStorage<CtrlWord, sizeof(Group)>
+    CONST_EMPTY_GROUP = {EMPTY};
+
+// The heap memory layout for N buckets of size S:
+//
+// ================================================
+// Fields: | buckets | ctrl bytes |     group     |
+// ------------------------------------------------
+// Size:   |   N*S   |     N      | sizeof(Group) |
+// ================================================
+//                   ^
+//                   |
+//        Store this position in RawTable.
+//
+// The trailing group part is to make sure we can always load
+// a whole group of control bytes.
+template <typename T> struct TableLayout {
+  // Pick the largest alignment between T and Group.
+  static constexpr inline size_t ALIGNMENT =
+      alignof(T) > alignof(Group) ? alignof(T) : alignof(Group);
+  size_t offset;
+  size_t size;
+
+  TableLayout(size_t offset, size_t size) : offset(offset), size(size) {}
+
+  // We want to find an aligned boundary, put buckets to its left and
+  // put ctrl words to its right. So we just trim the trailing ones from
+  // (buckets * sizeof(T) + alignment - 1).
+  static TableLayout checked(size_t buckets, bool &valid) {
+    valid = true;
+    SafeMemSize padded = SafeMemSize{buckets} * SafeMemSize{sizeof(T)} +
+                         SafeMemSize{ALIGNMENT - 1};
+    valid = valid && padded.valid();
+    size_t offset = static_cast<size_t>(padded) & ~(ALIGNMENT - 1);
+
+    SafeMemSize safe_size =
+        SafeMemSize{offset} + SafeMemSize{buckets} + SafeMemSize{sizeof(Group)};
+    valid = valid && safe_size.valid();
+    size_t size = static_cast<size_t>(safe_size);
+
+    return {offset, size};
+  }
+
+  static TableLayout unchecked(size_t buckets) {
+    size_t offset = (buckets * sizeof(T) + ALIGNMENT - 1) & ~(ALIGNMENT - 1);
+    size_t size = offset + buckets + sizeof(Group);
+    return {offset, size};
+  }
+};
+
+// We will not consider zero-sized types.
+template <typename T> struct Bucket {
+  static_assert(sizeof(T) > 1, "zero sized type not allowed");
+
+  T *ptr;
+
+  Bucket(T *base, size_t index) : ptr(base - index) {}
+
+  size_t base_index(T *base) const { return base - ptr; }
+
+  T &operator*() { return ptr[-1]; }
+
+  Bucket operator+(size_t offset) { return {ptr, offset}; }
+
+  bool operator==(const Bucket &other) { return ptr == other.ptr; }
+
+  operator bool() { return ptr != nullptr; }
+};
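To make the TableLayout arithmetic concrete, here is the same computation that `unchecked` performs, done by hand for an assumed 8-byte bucket type, a 16-byte Group and 16 buckets (illustrative values only, not part of the patch):

```cpp
#include <cassert>
#include <cstddef>

// Assume 8-byte buckets, a 16-byte Group (the SSE2 case) and 16 buckets.
constexpr size_t kBucketSize = 8, kGroupSize = 16, kBuckets = 16;
constexpr size_t kAlign = kGroupSize; // max(alignof(T), alignof(Group))

int main() {
  // Buckets occupy [0, offset); the ctrl pointer is placed at `offset`.
  size_t offset = (kBuckets * kBucketSize + kAlign - 1) & ~(kAlign - 1);
  // One ctrl byte per bucket plus a trailing Group, so that a group load
  // starting at any ctrl byte stays inside the allocation.
  size_t size = offset + kBuckets + kGroupSize;
  assert(offset == 128);
  assert(size == 128 + 16 + 16);
}
```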
+
+// T should be a trivial type.
+template <typename T, bool ENABLE_DELETION, bool ENABLE_RESIZE>
+class RawTable {
+  using Layout = TableLayout<T>;
+  // Number of buckets minus one; the bucket count is a power of two.
+  size_t bucket_mask;
+  // Pointer to the control words.
+  CtrlWord *ctrl;
+  // Number of items that can still be inserted before growth is required.
+  size_t growth_left;
+  // Number of items in the table.
+  size_t items;
+
+public:
+  static RawTable create_invalid() {
+    RawTable table{};
+    table.bucket_mask = 0;
+    table.ctrl = nullptr;
+    table.growth_left = 0;
+    table.items = 0;
+    return table;
+  }
+
+  static RawTable uninitialized(size_t buckets) {
+    RawTable table = RawTable::create_invalid();
+    bool valid;
+    Layout layout = Layout::checked(buckets, valid);
+
+    if (unlikely(!valid))
+      return table;
+
+    CtrlWord *address = static_cast<CtrlWord *>(
+        aligned_alloc(Layout::ALIGNMENT, layout.size));
+    if (unlikely(address == nullptr))
+      return table;
+
+    table.bucket_mask = buckets - 1;
+    table.ctrl = address + layout.offset;
+    table.growth_left = bucket_mask_to_capacity(table.bucket_mask);
+    return table;
+  }
+
+  static RawTable create() {
+    RawTable table{};
+    table.ctrl = const_cast<CtrlWord *>(&CONST_EMPTY_GROUP.data[0]);
+    table.bucket_mask = 0;
+    table.growth_left = 0;
+    table.items = 0;
+    return table;
+  }
+
+  static RawTable with_capacity(size_t capacity) {
+    if (capacity == 0)
+      return create();
+
+    // Additional check for capacity overflow.
+    auto safe_cap = SafeMemSize{capacity};
+    safe_cap = safe_cap * SafeMemSize{size_t{8}};
+    if (!safe_cap.valid()) {
+      return create_invalid();
+    }
+
+    size_t buckets = capacity_to_buckets(capacity);
+    RawTable table = uninitialized(buckets);
+    if (likely(table.ctrl != nullptr))
+      memset(reinterpret_cast<void *>(table.ctrl), EMPTY,
+             table.num_ctrl_words());
+    return table;
+  }
+
+  void release() {
+    if (!is_empty_singleton()) {
+      Layout layout = Layout::unchecked(num_buckets());
+      free(ctrl - layout.offset);
+    }
+    ctrl = nullptr;
+  }
+
+  bool is_valid() { return ctrl != nullptr; }
+
+private:
+  size_t num_buckets() const { return bucket_mask + 1; }
+  size_t num_ctrl_words() const { return num_buckets() + sizeof(Group); }
+
+  T *data_end() const { return reinterpret_cast<T *>(ctrl); }
+  T *data_begin() const { return reinterpret_cast<T *>(ctrl) - num_buckets(); }
+
+  Bucket<T> bucket_at(size_t index) const { return {data_end(), index}; }
+  size_t bucket_index(const Bucket<T> &bucket) const {
+    return bucket.base_index(data_end());
+  }
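The bucket accessors above address storage backwards: `data_end()` is the ctrl pointer and bucket `i` lives immediately below it, which is why `Bucket` dereferences `ptr[-1]`. A small sketch of the same pointer arithmetic on a plain array (illustrative only):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  int storage[4] = {10, 11, 12, 13}; // pretend these are 4 buckets of T
  int *base = storage + 4;           // == data_end(), i.e. where ctrl begins

  for (size_t i = 0; i < 4; ++i) {
    int *ptr = base - i;  // Bucket(base, i) stores base - i
    int &value = ptr[-1]; // operator*: one slot below ptr
    assert(&value == &storage[3 - i]);            // bucket i counts down from ctrl
    assert(static_cast<size_t>(base - ptr) == i); // base_index() recovers i
  }
}
```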
+
+  // Sets a control byte, and possibly also the replicated control byte at
+  // the end of the array.
+  void set_ctrl(size_t index, CtrlWord value) {
+    // Replicate the first sizeof(Group) control bytes at the end of
+    // the array without using a branch:
+    // - If index >= sizeof(Group) then index == index2.
+    // - Otherwise index2 == bucket_mask + 1 + index.
+    //
+    // The very last replicated control byte is never actually read because
+    // we mask the initial index for unaligned loads, but we write it
+    // anyway because it makes the set_ctrl implementation simpler.
+    //
+    // If there are fewer buckets than sizeof(Group) then this code will
+    // replicate the buckets at the end of the trailing group. For example
+    // with 2 buckets and a group size of 4, the control bytes will look
+    // like this:
+    // =============================================
+    // |    Real    |          Replicated           |
+    // ---------------------------------------------
+    // | [A] | [B] | [EMPTY] | [EMPTY] | [A] | [B] |
+    // =============================================
+    size_t index2 = ((index - sizeof(Group)) & bucket_mask) + sizeof(Group);
+    ctrl[index] = value;
+    ctrl[index2] = value;
+  }
+
+  ProbeSequence probe_sequence(HashWord hash) const {
+    return {h1(hash) & bucket_mask, 0};
+  }
+
+  size_t proper_insertion_slot(size_t index) const {
+    // In tables smaller than the group width, trailing control
+    // bytes outside the range of the table are filled with
+    // EMPTY entries. These will unfortunately trigger a
+    // match, but once masked may point to a full bucket that
+    // is already occupied. We detect this situation here and
+    // perform a second scan starting at the beginning of the
+    // table. This second scan is guaranteed to find an empty
+    // slot (due to the load factor) before hitting the trailing
+    // control bytes (containing EMPTY).
+    if (unlikely(is_full(ctrl[index]))) {
+      return Group::aligned_load(ctrl)
+          .mask_empty_or_deleted()
+          .lowest_set_bit_nonzero();
+    }
+    return index;
+  }
+
+  size_t find_insert_slot(HashWord hash) const {
+    ProbeSequence seq = probe_sequence(hash);
+    while (true) {
+      Group group = Group::load(&ctrl[seq.position]);
+      BitMask empty_slot = group.mask_empty_or_deleted();
+      if (empty_slot.any_bit_set()) {
+        size_t result =
+            (seq.position + empty_slot.lowest_set_bit_nonzero()) & bucket_mask;
+        return proper_insertion_slot(result);
+      }
+      seq.move_next(bucket_mask);
+    }
+  }
+
+  template <typename Equal, typename Hasher>
+  Bucket<T> find_or_insert_with_hash(HashWord hash, const T &value, Equal eq,
+                                     Hasher hasher) {
+    CtrlWord h2_hash = h2(hash);
+    ProbeSequence seq = probe_sequence(hash);
+    while (true) {
+      Group group = Group::load(&ctrl[seq.position]);
+      for (size_t bit : group.match_byte(h2_hash)) {
+        size_t index = (seq.position + bit) & bucket_mask;
+        auto bucket = bucket_at(index);
+        if (likely(eq(*bucket, value)))
+          return bucket;
+      }
+
+      if constexpr (ENABLE_DELETION) {
+        BitMask empty_slot = group.mask_empty();
+        if (likely(empty_slot.any_bit_set())) {
+          size_t index = find_insert_slot(hash);
+          return insert_at(index, hash, value, hasher);
+        }
+      } else {
+        BitMask empty_slot = group.mask_empty_or_deleted();
+        if (likely(empty_slot.any_bit_set())) {
+          size_t index = (seq.position + empty_slot.lowest_set_bit_nonzero()) &
+                         bucket_mask;
+          index = proper_insertion_slot(index);
+          return insert_at(index, hash, value, hasher);
+        }
+      }
+
+      seq.move_next(bucket_mask);
+    }
+  }
+
+  template <typename Equal>
+  Bucket<T> find_with_hash(HashWord hash, const T &value, Equal eq) const {
+    CtrlWord h2_hash = h2(hash);
+    ProbeSequence seq = probe_sequence(hash);
+    while (true) {
+      Group group = Group::load(&ctrl[seq.position]);
+      for (size_t bit : group.match_byte(h2_hash)) {
+        size_t index = (seq.position + bit) & bucket_mask;
+        auto bucket = bucket_at(index);
+        if (likely(eq(*bucket, value)))
+          return bucket;
+      }
+
+      if constexpr (ENABLE_DELETION) {
+        if (likely(group.mask_empty().any_bit_set()))
+          return {nullptr, 0};
+      } else {
+        // Only EMPTY will appear; no need to distinguish EMPTY and DELETED.
+        if (likely(group.mask_empty_or_deleted().any_bit_set()))
+          return {nullptr, 0};
+      }
+
+      seq.move_next(bucket_mask);
+    }
+  }
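The probe loop in `find_insert_slot` relies on the triangular stride of `ProbeSequence::move_next` visiting every group exactly once when the bucket count is a power of two. A quick check of that property, assuming a 16-byte group and 256 buckets (illustrative only):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  constexpr size_t kGroup = 16, kBuckets = 256, kMask = kBuckets - 1;
  bool seen[kBuckets / kGroup] = {};
  size_t position = 0, stride = 0; // start at a group-aligned h1 for clarity
  for (size_t i = 0; i < kBuckets / kGroup; ++i) {
    seen[position / kGroup] = true; // positions stay group-aligned here
    stride += kGroup;               // same update as move_next()
    position = (position + stride) & kMask;
  }
  for (bool visited : seen)
    assert(visited); // every group was probed before the sequence repeats
}
```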
+
+  void set_ctrl_h2(size_t index, HashWord hash) { set_ctrl(index, h2(hash)); }
+
+  CtrlWord replace_ctrl_h2(size_t index, HashWord hash) {
+    CtrlWord prev = ctrl[index];
+    set_ctrl_h2(index, hash);
+    return prev;
+  }
+
+  bool is_bucket_full(size_t index) const { return is_full(ctrl[index]); }
+
+  struct Slot {
+    size_t index;
+    CtrlWord prev_ctrl;
+  };
+
+  Slot prepare_insert_slot(HashWord hash) {
+    size_t index = find_insert_slot(hash);
+    CtrlWord prev_ctrl = ctrl[index];
+    set_ctrl_h2(index, hash);
+    return {index, prev_ctrl};
+  }
+
+  void prepare_rehash_inplace() {
+    // Convert full to DELETED and DELETED to EMPTY, so that DELETED can be
+    // used as an indicator for slots that still need to be rehashed.
+    for (size_t i = 0; i < num_buckets(); i += sizeof(Group)) {
+      Group group = Group::aligned_load(&ctrl[i]);
+      Group converted = group.convert_special_to_empty_and_full_to_deleted();
+      converted.aligned_store(&ctrl[i]);
+    }
+
+    // Handle the case where the table is smaller than the group size.
+    if (num_buckets() < sizeof(Group)) {
+      memcpy(&ctrl[sizeof(Group)], &ctrl[0], num_buckets());
+    } else {
+      inline_memcpy(reinterpret_cast<char *>(&ctrl[num_buckets()]),
+                    reinterpret_cast<const char *>(&ctrl[0]), sizeof(Group));
+    }
+  }
+
+  void record_item_insert_at(size_t index, CtrlWord prev_ctrl, HashWord hash) {
+    growth_left -= special_is_empty(prev_ctrl) ? 1 : 0;
+    set_ctrl_h2(index, hash);
+    items++;
+  }
+
+  bool is_in_same_group(size_t index, size_t new_index, HashWord hash) const {
+    size_t probe_position = probe_sequence(hash).position;
+    size_t position[2] = {
+        ((index - probe_position) & bucket_mask) / sizeof(Group),
+        ((new_index - probe_position) & bucket_mask) / sizeof(Group),
+    };
+    return position[0] == position[1];
+  }
+
+  bool is_empty_singleton() const { return ctrl == CONST_EMPTY_GROUP.data; }
+
+  RawTable prepare_resize(size_t capacity) {
+    RawTable new_table = RawTable::with_capacity(capacity);
+    if (likely(new_table.ctrl != nullptr)) {
+      new_table.growth_left -= items;
+      new_table.items += items;
+    }
+    return new_table;
+  }
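Small tables keep a mirrored copy of their control bytes so that a group load starting at any real index stays in bounds; `set_ctrl` maintains that mirror and `prepare_rehash_inplace` restores it after the conversion pass. A sketch with 2 buckets and an assumed group width of 4 (EMPTY assumed to be 0xff, as in dispatch.h):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

int main() {
  constexpr size_t kGroup = 4, kBuckets = 2, kMask = kBuckets - 1;
  constexpr uint8_t kEmpty = 0xff; // assumed EMPTY encoding
  uint8_t ctrl[kBuckets + kGroup];
  std::memset(ctrl, kEmpty, sizeof(ctrl));

  // set_ctrl(index, value): for index < sizeof(Group) the mirror index is
  // bucket_mask + 1 + index; otherwise it is index itself.
  for (size_t index = 0; index < kBuckets; ++index) {
    uint8_t value = static_cast<uint8_t>(0x11 * (index + 1)); // tags 0x11, 0x22
    size_t index2 = ((index - kGroup) & kMask) + kGroup;
    ctrl[index] = value;
    ctrl[index2] = value;
  }

  // Layout is now [0x11][0x22][EMPTY][EMPTY][0x11][0x22], matching the
  // replication diagram in set_ctrl.
  const uint8_t expected[] = {0x11, 0x22, 0xff, 0xff, 0x11, 0x22};
  assert(std::memcmp(ctrl, expected, sizeof(expected)) == 0);
}
```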
+
+  template <typename Hasher> void rehash_in_place(Hasher hasher) {
+    prepare_rehash_inplace();
+
+    for (size_t idx = 0; idx < num_buckets(); ++idx) {
+      if (ctrl[idx] != DELETED) {
+        continue;
+      }
+
+      Bucket<T> bucket = bucket_at(idx);
+
+      while (true) {
+        HashWord hash = hasher(*bucket_at(idx));
+        size_t new_idx = find_insert_slot(hash);
+        // Probing works by scanning through all of the control
+        // bytes in groups, which may not be aligned to the group
+        // size. If both the new and old position fall within the
+        // same unaligned group, then there is no benefit in moving
+        // it and we can just continue to the next item.
+        if (likely(is_in_same_group(idx, new_idx, hash))) {
+          set_ctrl_h2(idx, hash);
+          break; // continue outer loop
+        }
+
+        Bucket<T> new_bucket = bucket_at(new_idx);
+        CtrlWord prev_ctrl = replace_ctrl_h2(new_idx, hash);
+
+        if (prev_ctrl == EMPTY) {
+          set_ctrl(idx, EMPTY);
+          *new_bucket = *bucket;
+          break; // continue outer loop
+        }
+
+        // The target slot is occupied: swap the two items and keep probing
+        // with the displaced one.
+        T temp = *new_bucket;
+        *new_bucket = *bucket;
+        *bucket = temp;
+      }
+    }
+
+    growth_left = bucket_mask_to_capacity(bucket_mask) - items;
+  }
+
+  template <typename Hasher> bool resize(size_t new_capacity, Hasher hasher) {
+    RawTable new_table = prepare_resize(new_capacity);
+    if (new_table.ctrl == nullptr)
+      return false;
+
+    for (size_t i = 0; i < num_buckets(); ++i) {
+      if (!is_bucket_full(i))
+        continue;
+
+      Bucket<T> bucket = bucket_at(i);
+      HashWord hash = hasher(*bucket);
+
+      // We can use a simpler version of insert() here since:
+      // - there are no DELETED entries.
+      // - we know there is enough space in the table.
+      // - all elements are unique.
+      Slot slot = new_table.prepare_insert_slot(hash);
+      Bucket<T> new_bucket = new_table.bucket_at(slot.index);
+
+      *new_bucket = *bucket;
+    }
+
+    release();
+    *this = new_table;
+
+    return true;
+  }
+
+  template <typename Hasher>
+  bool reserve_rehash(size_t additional, Hasher hasher) {
+    SafeMemSize checked_new_items =
+        SafeMemSize{additional} + SafeMemSize{items};
+    if (!checked_new_items.valid())
+      return false;
+
+    size_t new_items = static_cast<size_t>(checked_new_items);
+    size_t full_capacity = bucket_mask_to_capacity(bucket_mask);
+
+    if constexpr (ENABLE_DELETION) {
+      if (new_items <= full_capacity / 2) {
+        rehash_in_place(hasher);
+        return true;
+      }
+    }
+
+    if constexpr (ENABLE_RESIZE) {
+      size_t new_capacity =
+          new_items > (full_capacity + 1) ? new_items : (full_capacity + 1);
+      return resize(new_capacity, hasher);
+    }
+
+    return false;
+  }
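`reserve_rehash` chooses between reclaiming DELETED slots in place and growing the table. A sketch of just that decision, using the same thresholds (names are illustrative, not from the patch):

```cpp
#include <cassert>
#include <cstddef>

enum class Plan { Fail, RehashInPlace, Resize };

Plan plan_growth(size_t items, size_t additional, size_t full_capacity,
                 bool enable_deletion, bool enable_resize) {
  size_t new_items = items + additional; // reserve_rehash checks overflow here
  if (enable_deletion && new_items <= full_capacity / 2)
    return Plan::RehashInPlace; // mostly-deleted table: reclaim tombstones
  if (enable_resize)
    return Plan::Resize;        // genuinely full: grow to a larger table
  return Plan::Fail;            // fixed-size, insert-only table is stuck
}

int main() {
  assert(plan_growth(100, 1, 448, true, true) == Plan::RehashInPlace);
  assert(plan_growth(440, 16, 448, true, true) == Plan::Resize);
  assert(plan_growth(440, 16, 448, true, false) == Plan::Fail);
}
```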
+
+  template <typename Hasher>
+  Bucket<T> insert_at(size_t index, HashWord hash, const T &element,
+                      Hasher hasher) {
+    CtrlWord prev_ctrl = ctrl[index];
+
+    // If we have reached the full load factor:
+    //
+    // - When deletion is allowed and the found slot is deleted, then it is
+    //   okay to insert.
+    // - Otherwise, the slot is empty and inserting into it would violate
+    //   the load factor constraints, so we need to rehash the table first.
+    //   There are several cases:
+    //
+    //   - If deletion is enabled, we may be able to rehash the table in
+    //     place when the number of items is less than half of the capacity.
+    //   - If resizing is enabled, we may be able to grow the table.
+    //   - If neither deletion nor resizing is enabled, we fail immediately.
+    if (unlikely(growth_left == 0 &&
+                 (!ENABLE_DELETION || special_is_empty(prev_ctrl)))) {
+      if (!reserve_rehash(1, hasher))
+        return {nullptr, 0};
+      index = find_insert_slot(hash);
+    }
+
+    record_item_insert_at(index, prev_ctrl, hash);
+
+    Bucket<T> bucket = bucket_at(index);
+    *bucket = element;
+
+    return bucket;
+  }
+
+public:
+  template <typename Hasher>
+  Bucket<T> insert(const T &element, Hasher hasher) {
+    HashWord hash = hasher(element);
+    size_t index = find_insert_slot(hash);
+    return insert_at(index, hash, element, hasher);
+  }
+
+  template <typename Hasher, typename Equal>
+  Bucket<T> find(const T &element, Hasher hasher, Equal equal) const {
+    HashWord hash = hasher(element);
+    return find_with_hash(hash, element, equal);
+  }
+
+  template <typename Hasher, typename Equal>
+  Bucket<T> find_or_insert(const T &element, Hasher hasher, Equal equal) {
+    HashWord hash = hasher(element);
+    return find_or_insert_with_hash(hash, element, equal, hasher);
+  }
+
+  void erase(Bucket<T> item) {
+    size_t index = bucket_index(item);
+    size_t index_before = (index - sizeof(Group)) & bucket_mask;
+    BitMask empty_before = Group::load(&ctrl[index_before]).mask_empty();
+    BitMask empty_after = Group::load(&ctrl[index]).mask_empty();
+    CtrlWord ctrl_value;
+    if (empty_before.leading_zeros() + empty_after.trailing_zeros() >=
+        sizeof(Group)) {
+      ctrl_value = DELETED;
+    } else {
+      growth_left++;
+      ctrl_value = EMPTY;
+    }
+    set_ctrl(index, ctrl_value);
+    items--;
+  }
+};
+
+} // namespace __llvm_libc::internal::swisstable
+#endif // LLVM_LIBC_SUPPORT_SWISSTABLE_H
diff --git a/libc/src/__support/swisstable/CMakeLists.txt b/libc/src/__support/swisstable/CMakeLists.txt
--- a/libc/src/__support/swisstable/CMakeLists.txt
+++ b/libc/src/__support/swisstable/CMakeLists.txt
@@ -13,4 +13,13 @@
     common.h
   DEPENDS
     libc.src.__support.builtin_wrappers
+)
+
+add_header_library(
+  safe_mem_size
+  HDRS
+    safe_mem_size.h
+  DEPENDS
+    libc.src.__support.CPP.limits
+    libc.src.__support.common
 )
\ No newline at end of file
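A hedged usage sketch of the RawTable interface defined above. The template parameter order `<T, ENABLE_DELETION, ENABLE_RESIZE>` is reconstructed from the unit tests later in this diff and is an assumption, as are the callable shapes of the hasher and equality predicates:

```cpp
#include "src/__support/swisstable.h"

#include <stdint.h>

using namespace __llvm_libc::internal::swisstable;

static size_t identity_hash(const uint64_t &x) { return x; }
static bool same(const uint64_t &a, const uint64_t &b) { return a == b; }

bool demo() {
  // Assumed parameter order: deletion enabled, resizing enabled.
  auto table = RawTable<uint64_t, true, true>::with_capacity(64);
  if (!table.is_valid())
    return false;
  // insert/find/find_or_insert take the hasher (and equality) as callables.
  uint64_t key = 42;
  table.insert(key, identity_hash);
  auto hit = table.find(key, identity_hash, same);
  bool found = static_cast<bool>(hit); // Bucket converts to bool
  table.erase(hit);                    // only meaningful with ENABLE_DELETION
  table.release();                     // frees the heap allocation
  return found;
}
```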
diff --git a/libc/src/__support/swisstable/safe_mem_size.h b/libc/src/__support/swisstable/safe_mem_size.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/swisstable/safe_mem_size.h
@@ -0,0 +1,46 @@
+//===-- Memory Size Checker for SwissTable ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SUPPORT_SWISSTABLE_SAFE_MEM_SIZE_H
+#define LLVM_LIBC_SUPPORT_SWISSTABLE_SAFE_MEM_SIZE_H
+
+#include "src/__support/CPP/limits.h"
+#include "src/__support/common.h"
+
+#include <stddef.h>
+#include <sys/types.h>
+
+namespace __llvm_libc::internal {
+
+// Limit memory sizes to the maximum of ssize_t.
+class SafeMemSize {
+private:
+  static constexpr size_t MAX_MEM_SIZE =
+      static_cast<size_t>(__llvm_libc::cpp::numeric_limits<ssize_t>::max());
+  ssize_t value;
+  explicit SafeMemSize(ssize_t value) : value(value) {}
+
+public:
+  explicit SafeMemSize(size_t value)
+      : value(value <= MAX_MEM_SIZE ? static_cast<ssize_t>(value) : -1) {}
+  operator size_t() { return static_cast<size_t>(value); }
+  bool valid() { return value >= 0; }
+  SafeMemSize operator+(const SafeMemSize &other) {
+    ssize_t result;
+    if (unlikely((value | other.value) < 0) ||
+        unlikely(__builtin_add_overflow(value, other.value, &result)))
+      result = -1;
+    return SafeMemSize{result};
+  }
+  SafeMemSize operator*(const SafeMemSize &other) {
+    ssize_t result;
+    if (unlikely((value | other.value) < 0) ||
+        unlikely(__builtin_mul_overflow(value, other.value, &result)))
+      result = -1;
+    return SafeMemSize{result};
+  }
+};
+
+} // namespace __llvm_libc::internal
+#endif // LLVM_LIBC_SUPPORT_SWISSTABLE_SAFE_MEM_SIZE_H
diff --git a/libc/test/src/__support/swisstable/CMakeLists.txt b/libc/test/src/__support/swisstable/CMakeLists.txt
--- a/libc/test/src/__support/swisstable/CMakeLists.txt
+++ b/libc/test/src/__support/swisstable/CMakeLists.txt
@@ -33,3 +33,37 @@
     libc.src.string.memcmp
     libc.src.__support.swisstable.dispatch
 )
+
+
+add_libc_unittest(
+  safe_mem_size_test
+  SUITE
+    libc_support_unittests
+  SRCS
+    safe_mem_size_test.cpp
+  DEPENDS
+    libc.src.__support.swisstable.safe_mem_size
+)
+
+add_header_library(
+  test_utils
+  HDRS
+    test_utils.h
+  DEPENDS
+    libc.include.stdlib
+    libc.src.string.memmove
+)
+
+add_libc_unittest(
+  hashtable_test
+  SUITE
+    libc_support_unittests
+  SRCS
+    hashtable_test.cpp
+  DEPENDS
+    .test_utils
+    libc.src.__support.swisstable
+    libc.src.stdlib.rand
+    libc.src.stdlib.srand
+    libc.src.sys.random.getrandom
+)
diff --git a/libc/test/src/__support/swisstable/bitmask_test.cpp b/libc/test/src/__support/swisstable/bitmask_test.cpp
--- a/libc/test/src/__support/swisstable/bitmask_test.cpp
+++ b/libc/test/src/__support/swisstable/bitmask_test.cpp
@@ -1,7 +1,5 @@
 #include "src/__support/swisstable/dispatch.h"
 #include "utils/UnitTest/LibcTest.h"
-#include <stddef.h>
-#include <stdint.h>
 
 namespace __llvm_libc::internal::swisstable {
 using ShortBitMask = BitMaskAdaptor;
diff --git a/libc/test/src/__support/swisstable/group_test.cpp b/libc/test/src/__support/swisstable/group_test.cpp
--- a/libc/test/src/__support/swisstable/group_test.cpp
+++ b/libc/test/src/__support/swisstable/group_test.cpp
@@ -1,4 +1,3 @@
-#include <stdint.h>
 #if SWISSTABLE_TEST_USE_GENERIC_GROUP
 #include "src/__support/swisstable/generic.h"
 #define SWISSTABLE_TEST_SUITE(X) TEST(LlvmLibcSwissTableGroupGeneric, X)
@@ -8,7 +7,6 @@
 #endif
 #include "src/string/memcmp.h"
 #include "utils/UnitTest/LibcTest.h"
-#include <stddef.h>
 
 namespace __llvm_libc::internal::swisstable {
diff --git a/libc/test/src/__support/swisstable/hashtable_test.cpp b/libc/test/src/__support/swisstable/hashtable_test.cpp
new file mode 100644
--- /dev/null
+++ b/libc/test/src/__support/swisstable/hashtable_test.cpp
@@ -0,0 +1,211 @@
+#include "src/__support/builtin_wrappers.h"
+#include "src/__support/swisstable.h"
+#include "src/stdlib/rand.h"
+#include "src/stdlib/srand.h"
+#include "src/sys/random/getrandom.h"
+#include "test/src/__support/swisstable/test_utils.h"
+#include "utils/UnitTest/LibcTest.h"
+
+namespace __llvm_libc::internal::swisstable {
+
+enum TestAction { Insert = 0, Find = 1, Delete = 2 };
+
+using FixedInsertOnlyTable = RawTable<uint64_t, false, false>;
+using ResizableInsertOnlyTable = RawTable<uint64_t, false, true>;
+using FixedTable = RawTable<uint64_t, true, false>;
+using ResizableTable = RawTable<uint64_t, true, true>;
+
+static inline uint64_t sample() {
+  size_t length = 64 - safe_clz(RAND_MAX);
+  uint64_t accumulator = 0;
+  for (size_t i = 0; i < 64; i += length) {
+    accumulator |= static_cast<uint64_t>(rand()) << i;
+  }
+  return accumulator;
+}
+
+static inline uint64_t bad_sample(uint64_t x) { return x << 16; }
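`sample()` stitches several `rand()` results into one 64-bit key because each call only yields `64 - clz(RAND_MAX)` usable bits. A sketch of the same accumulation, assuming the common `RAND_MAX` of 2^31 - 1 (so each call contributes 31 bits):

```cpp
#include <cassert>
#include <cstdint>

constexpr unsigned kBitsPerCall = 31; // 64 - clz(RAND_MAX) for RAND_MAX = 2^31 - 1

uint64_t assemble(uint64_t (*next_rand)()) {
  uint64_t accumulator = 0;
  // Shifts 0, 31 and 62 cover the whole 64-bit word in three calls.
  for (unsigned shift = 0; shift < 64; shift += kBitsPerCall)
    accumulator |= next_rand() << shift;
  return accumulator;
}

int main() {
  // If every call returned the maximum value, all 64 bits end up set.
  auto max_rand = []() -> uint64_t { return 0x7fffffff; };
  assert(assemble(max_rand) == ~uint64_t{0});
}
```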
+
+static inline TestAction next_move(int max = 3) {
+  return static_cast<TestAction>(((rand() % max) + max) % max);
+}
+
+static inline bool test_equal(const uint64_t &x, const uint64_t &y) {
+  return x == y;
+}
+static inline size_t test_hash(const uint64_t &x) { return x; }
+
+template <typename Table, int OpRange>
+struct LlvmLibcSwissTableTest : testing::Test {
+
+  void init_seed() {
+    unsigned int seed;
+    __llvm_libc::getrandom(&seed, sizeof(seed), 0);
+    __llvm_libc::srand(seed);
+  }
+
+  void bad_hash_example() {
+    auto table = Table::with_capacity(16384);
+    for (uint64_t i = 0; i < 16384; ++i) {
+      auto t = bad_sample(i);
+      table.insert(t, test_hash);
+    }
+    for (uint64_t i = 0; i < 16384; ++i) {
+      ASSERT_EQ(*table.find(bad_sample(i), test_hash, test_equal),
+                bad_sample(i));
+    }
+    for (uint64_t i = 16384; i < 32768; ++i) {
+      ASSERT_FALSE(static_cast<bool>(table.find(i, test_hash, test_equal)));
+      ASSERT_FALSE(
+          static_cast<bool>(table.find(bad_sample(i), test_hash, test_equal)));
+    }
+    table.release();
+  }
+
+  void random_walk(size_t attempts, size_t capacity) {
+    BrutalForceSet set;
+    auto table = Table::with_capacity(capacity);
+    for (size_t i = 0; i < attempts; ++i) {
+      switch (next_move(OpRange)) {
+      case Insert: {
+        auto x = sample();
+        set.insert(x);
+        if (x & 1) {
+          if (!table.find(x, test_hash, test_equal)) {
+            ASSERT_EQ(*table.insert(x, test_hash), x);
+          }
+        } else {
+          ASSERT_EQ(*table.find_or_insert(x, test_hash, test_equal), x);
+        }
+        ASSERT_EQ(*table.find(x, test_hash, test_equal), x);
+
+        break;
+      }
+      case Find: {
+        if (!set.size)
+          continue;
+        auto x = sample();
+        ASSERT_EQ(static_cast<bool>(table.find(x, test_hash, test_equal)),
+                  set.find(x));
+        auto index = x % set.size;
+        ASSERT_EQ(*table.find(set.data[index], test_hash, test_equal),
+                  set.data[index]);
+        break;
+      }
+      case Delete: {
+        if (!set.size)
+          continue;
+        auto x = sample();
+        auto index = x % set.size;
+        auto target = set.data[index];
+        set.erase(target);
+        ASSERT_EQ(*table.find(target, test_hash, test_equal), target);
+        table.erase(table.find(target, test_hash, test_equal));
+        ASSERT_FALSE(
+            static_cast<bool>(table.find(target, test_hash, test_equal)));
+        break;
+      }
+      }
+    }
+    for (size_t i = 0; i < set.size; ++i) {
+      ASSERT_EQ(*table.find(set.data[i], test_hash, test_equal), set.data[i]);
+    }
+    for (size_t i = 0; i < attempts; ++i) {
+      auto x = sample();
+      ASSERT_EQ(static_cast<bool>(table.find(x, test_hash, test_equal)),
+                set.find(x));
+    }
+    table.release();
+  }
+};
+
+using LlvmLibcSwissTableFixedInsertOnlyTableTest =
+    LlvmLibcSwissTableTest<FixedInsertOnlyTable, 2>;
+using LlvmLibcSwissTableResizableInsertOnlyTableTest =
+    LlvmLibcSwissTableTest<ResizableInsertOnlyTable, 2>;
+using LlvmLibcSwissTableFixedTableTest = LlvmLibcSwissTableTest<FixedTable, 3>;
+using LlvmLibcSwissTableResizableTableTest =
+    LlvmLibcSwissTableTest<ResizableTable, 3>;
+
+TEST_F(LlvmLibcSwissTableFixedInsertOnlyTableTest, RandomWalk) {
+  init_seed();
+  for (size_t i = 0; i < 100; ++i) {
+    this->random_walk(10000, 10000);
+  }
+}
+TEST_F(LlvmLibcSwissTableResizableInsertOnlyTableTest, RandomWalk) {
+  init_seed();
+  for (size_t i = 0; i < 100; ++i) {
+    this->random_walk(10000, 0);
+  }
+}
+TEST_F(LlvmLibcSwissTableFixedTableTest, RandomWalk) {
+  init_seed();
+  for (size_t i = 0; i < 100; ++i) {
+    this->random_walk(10000, 10000);
+  }
+}
+TEST_F(LlvmLibcSwissTableResizableTableTest, RandomWalk) {
+  init_seed();
+  for (size_t i = 0; i < 100; ++i) {
+    this->random_walk(10000, 0);
+  }
+}
+
+TEST_F(LlvmLibcSwissTableFixedInsertOnlyTableTest, BadHash) {
+  bad_hash_example();
+}
+TEST_F(LlvmLibcSwissTableResizableInsertOnlyTableTest, BadHash) {
+  bad_hash_example();
+}
+TEST_F(LlvmLibcSwissTableFixedTableTest, BadHash) { bad_hash_example(); }
+TEST_F(LlvmLibcSwissTableResizableTableTest, BadHash) { bad_hash_example(); }
+
+TEST(LlvmLibcSwissTableFixedInsertOnlyTableTest, Pressure) {
+  auto table = FixedInsertOnlyTable::with_capacity(512);
+  uint64_t i = 0;
+  for (; i < 512; ++i) {
+    if (!table.insert(i, test_hash)) {
+      break;
+    }
+  }
+  ASSERT_LE(i, static_cast<uint64_t>(512));
+  table.release();
+}
+
+TEST(LlvmLibcSwissTableFixedTableTest, InPlaceRehash) {
+  auto table = FixedTable::with_capacity(512);
+  uint64_t i = 0;
+  for (; i < 512; ++i) {
+    if (!table.insert(i, test_hash)) {
+      break;
+    }
+  }
+
+  uint64_t n = i;
+
+  ASSERT_GE(n, static_cast<uint64_t>(128));
+
+  for (i = 0; i < 128; ++i) {
+    table.erase(table.find(i, test_hash, test_equal));
+  }
+
+  for (i = 1000; i < 1128; ++i) {
+    table.find_or_insert(i, test_hash, test_equal);
+  }
+
+  for (i = 0; i < 128; ++i) {
+    ASSERT_FALSE(static_cast<bool>(table.find(i, test_hash, test_equal)));
+  }
+
+  for (i = n; i < 512; ++i) {
+    ASSERT_EQ(*table.find(i, test_hash, test_equal), i);
+  }
+
+  for (i = 1000; i < 1128; ++i) {
+    ASSERT_EQ(*table.find(i, test_hash, test_equal), i);
+  }
+
+  table.release();
+}
+} // namespace __llvm_libc::internal::swisstable
diff --git a/libc/test/src/__support/swisstable/safe_mem_size_test.cpp b/libc/test/src/__support/swisstable/safe_mem_size_test.cpp
new file mode 100644
--- /dev/null
+++ b/libc/test/src/__support/swisstable/safe_mem_size_test.cpp
@@ -0,0 +1,59 @@
+#include "src/__support/swisstable/safe_mem_size.h"
+#include "utils/UnitTest/LibcTest.h"
+
+namespace __llvm_libc::internal::swisstable {
+
+static inline constexpr size_t SAFE_MEM_SIZE_TEST_LIMIT =
+    static_cast<size_t>(__llvm_libc::cpp::numeric_limits<ssize_t>::max());
+
+TEST(LlvmLibcSwissTableSafeMemSize, Construction) {
+  ASSERT_FALSE(SafeMemSize{static_cast<size_t>(-1)}.valid());
+  ASSERT_FALSE(SafeMemSize{static_cast<size_t>(-2)}.valid());
+  ASSERT_FALSE(SafeMemSize{static_cast<size_t>(-1024 + 33)}.valid());
+  ASSERT_FALSE(SafeMemSize{static_cast<size_t>(-1024 + 66)}.valid());
+  ASSERT_FALSE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT + 1}.valid());
+  ASSERT_FALSE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT + 13}.valid());
+
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(1)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(1024 + 13)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(2048 - 13)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(4096 + 1)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(8192 - 1)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(16384 + 15)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(32768 * 3)}.valid());
+  ASSERT_TRUE(SafeMemSize{static_cast<size_t>(65536 * 13)}.valid());
+  ASSERT_TRUE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT}.valid());
+  ASSERT_TRUE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT - 1}.valid());
+  ASSERT_TRUE(SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT - 13}.valid());
+}
+
+TEST(LlvmLibcSwissTableSafeMemSize, Addition) {
+  auto max = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT};
+  auto half = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 2};
+  auto third = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 3};
+
+  ASSERT_TRUE(half.valid());
+  ASSERT_TRUE(third.valid());
+  ASSERT_TRUE((half + half).valid());
+  ASSERT_TRUE((third + third + third).valid());
+  ASSERT_TRUE((half + third).valid());
+
+  ASSERT_FALSE((max + SafeMemSize{static_cast<size_t>(1)}).valid());
+  ASSERT_FALSE((third + third + third + third).valid());
+  ASSERT_FALSE((half + half + half).valid());
+}
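The additions and multiplications under test compose because the -1 sentinel is sticky: once a SafeMemSize goes invalid (by construction or by overflow), every later `+` or `*` stays invalid, so callers only need a single `valid()` check at the end. A small sketch, assuming a 64-bit `size_t`:

```cpp
#include "src/__support/swisstable/safe_mem_size.h"

#include <cassert>
#include <cstddef>

int main() {
  using __llvm_libc::internal::SafeMemSize;
  // 2^32 * 2^32 does not fit in ssize_t on a 64-bit target.
  SafeMemSize big{size_t{1} << 32};
  SafeMemSize product = big * big;                        // overflow -> invalid
  SafeMemSize total = product + SafeMemSize{size_t{16}};  // stays invalid
  assert(!product.valid());
  assert(!total.valid()); // the -1 sentinel propagates through later ops
}
```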
+
+TEST(LlvmLibcSwissTableSafeMemSize, Multiplication) {
+  auto max = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT};
+  auto half = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 2};
+  auto third = SafeMemSize{SAFE_MEM_SIZE_TEST_LIMIT / 3};
+
+  ASSERT_TRUE((max * SafeMemSize{static_cast<size_t>(1)}).valid());
+  ASSERT_TRUE((max * SafeMemSize{static_cast<size_t>(0)}).valid());
+
+  ASSERT_FALSE((max * SafeMemSize{static_cast<size_t>(2)}).valid());
+  ASSERT_FALSE((half * half).valid());
+  ASSERT_FALSE((half * SafeMemSize{static_cast<size_t>(3)}).valid());
+  ASSERT_FALSE((third * SafeMemSize{static_cast<size_t>(4)}).valid());
+}
+} // namespace __llvm_libc::internal::swisstable
diff --git a/libc/test/src/__support/swisstable/test_utils.h b/libc/test/src/__support/swisstable/test_utils.h
new file mode 100644
--- /dev/null
+++ b/libc/test/src/__support/swisstable/test_utils.h
@@ -0,0 +1,65 @@
+#ifndef LLVM_LIBC_TEST_SUPPORT_SWISSTABLE_TEST_UTILS_H
+#define LLVM_LIBC_TEST_SUPPORT_SWISSTABLE_TEST_UTILS_H
+
+#include "src/string/memmove.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+namespace __llvm_libc::internal::swisstable {
+
+struct BrutalForceSet {
+  size_t capacity;
+  size_t size;
+  uint64_t *data;
+
+  BrutalForceSet() : capacity(16), size(0) {
+    data = static_cast<uint64_t *>(calloc(capacity, sizeof(uint64_t)));
+  }
+
+  uint64_t *binary_search(uint64_t target) {
+    size_t l = 0, r = size;
+    while (l < r) {
+      auto mid = (l + r) / 2;
+      if (data[mid] == target) {
+        return &data[mid];
+      }
+      if (data[mid] < target) {
+        l = mid + 1;
+      }
+      if (data[mid] > target) {
+        r = mid;
+      }
+    }
+    return &data[l];
+  }
+
+  void grow() {
+    data =
+        static_cast<uint64_t *>(realloc(data, sizeof(uint64_t) * capacity * 2));
+    capacity *= 2;
+  }
+
+  void insert(uint64_t target) {
+    if (size == capacity)
+      grow();
+    uint64_t *position = binary_search(target);
+    if (*position != target) {
+      __llvm_libc::memmove(position + 1, position,
+                           sizeof(uint64_t) * (data + size - position));
+      size++;
+      *position = target;
+    }
+  }
+
+  bool find(uint64_t target) { return target == *binary_search(target); }
+
+  void erase(uint64_t target) {
+    uint64_t *position = binary_search(target);
+    if (*position == target) {
+      __llvm_libc::memmove(position, position + 1,
+                           sizeof(uint64_t) * (data + size - position - 1));
+      size--;
+    }
+  }
+};
+
+} // namespace __llvm_libc::internal::swisstable
+#endif // LLVM_LIBC_TEST_SUPPORT_SWISSTABLE_TEST_UTILS_H
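`BrutalForceSet::insert` keeps `data` sorted: `binary_search` returns the insertion position (the first slot whose value is not less than the target), and the tail `[position, data + size)` is shifted right by one element before the new value is written. The same memmove arithmetic in isolation, with plain stdlib calls:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

int main() {
  uint64_t data[8] = {10, 20, 40, 50};
  size_t size = 4;
  uint64_t target = 30;
  uint64_t *position = data + 2; // first element >= 30

  // Shift the (data + size - position) == 2 trailing elements right by one.
  std::memmove(position + 1, position,
               sizeof(uint64_t) * (data + size - position));
  *position = target;
  ++size;

  assert(size == 5 && data[2] == 30 && data[3] == 40 && data[4] == 50);
}
```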