diff --git a/libc/src/__support/CPP/new.h b/libc/src/__support/CPP/new.h
new file mode 100644
--- /dev/null
+++ b/libc/src/__support/CPP/new.h
@@ -0,0 +1,85 @@
+//===-- Libc specific custom operator new -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_CPP_NEW_H
+#define LLVM_LIBC_SRC_SUPPORT_CPP_NEW_H
+
+#include <stddef.h> // For size_t
+#include <stdlib.h> // For malloc, free etc.
+
+// Defining members in the std namespace is not preferred. But, we do it here
+// so that we can use it to define the operator new which takes std::align_val_t
+// argument.
+namespace std {
+
+enum class align_val_t : size_t {};
+
+} // namespace std
+
+namespace __llvm_libc {
+
+// Records whether an allocation made through one of the `new (ac) ...`
+// placement forms declared below succeeded. Those operator new overloads are
+// noexcept, so on failure the new-expression yields nullptr instead of
+// throwing, and callers must test the checker before using the result:
+//
+//   AllocChecker ac;
+//   char *buf = new (ac) char[len];
+//   if (!ac)
+//     ...; // Allocation failed.
+//
+// The memory comes from ::malloc or ::aligned_alloc; release it with ::free.
+class AllocChecker {
+  bool success = false;
+  // Only the static allocation helpers below record a result.
+  AllocChecker &operator=(bool status) {
+    success = status;
+    return *this;
+  }
+
+public:
+  AllocChecker() = default;
+  // True iff the last allocation recorded in this checker succeeded.
+  operator bool() const { return success; }
+
+  static void *alloc(size_t s, AllocChecker &ac) {
+    void *mem = ::malloc(s);
+    ac = (mem != nullptr);
+    return mem;
+  }
+
+  static void *aligned_alloc(size_t s, std::align_val_t align,
+                             AllocChecker &ac) {
+    void *mem = ::aligned_alloc(static_cast<size_t>(align), s);
+    ac = (mem != nullptr);
+    return mem;
+  }
+};
+
+} // namespace __llvm_libc
+
+inline void *operator new(size_t size, __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::alloc(size, ac);
+}
+
+inline void *operator new(size_t size, std::align_val_t align,
+                          __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::aligned_alloc(size, align, ac);
+}
+
+inline void *operator new[](size_t size,
+                            __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::alloc(size, ac);
+}
+
+inline void *operator new[](size_t size, std::align_val_t align,
+                            __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::aligned_alloc(size, align, ac);
+}
+
+#endif // LLVM_LIBC_SRC_SUPPORT_CPP_NEW_H
diff --git a/libc/src/string/CMakeLists.txt b/libc/src/string/CMakeLists.txt
--- a/libc/src/string/CMakeLists.txt
+++ b/libc/src/string/CMakeLists.txt
@@ -16,8 +16,9 @@
   HDRS
     allocating_string_utils.h
   DEPENDS
-    libc.include.stdlib
     .memory_utils.memcpy_implementation
+    libc.include.stdlib
+    libc.src.__support.CPP.optional
 )
 
 add_entrypoint_object(
@@ -150,7 +151,9 @@
   DEPENDS
     .memory_utils.memcpy_implementation
     .string_utils
+    libc.include.errno
     libc.include.stdlib
+    libc.src.errno.errno
 )
 
 add_entrypoint_object(
diff --git a/libc/src/string/allocating_string_utils.h b/libc/src/string/allocating_string_utils.h
--- a/libc/src/string/allocating_string_utils.h
+++ b/libc/src/string/allocating_string_utils.h
@@ -9,24 +9,29 @@
 #ifndef LIBC_SRC_STRING_ALLOCATING_STRING_UTILS_H
 #define LIBC_SRC_STRING_ALLOCATING_STRING_UTILS_H
 
-#include "src/__support/CPP/bitset.h"
-#include "src/__support/common.h"
-#include "src/string/memory_utils/bzero_implementations.h"
-#include "src/string/memory_utils/memcpy_implementations.h"
+#include "src/__support/CPP/new.h"
+#include "src/__support/CPP/optional.h"
+#include "src/string/memory_utils/memcpy_implementations.h" // For inline_memcpy
 #include "src/string/string_utils.h"
+
 #include <stddef.h> // For size_t
-#include <stdlib.h> // For malloc
 
 namespace __llvm_libc {
 namespace internal {
 
-inline char *strdup(const char *src) {
+// Returns a copy of the null-terminated string `src` in a newly allocated
+// buffer, or cpp::nullopt if `src` is null or the allocation fails. The buffer
+// comes from ::malloc (via AllocChecker), so callers release it with ::free.
+// This is `inline` because the header is included from multiple translation
+// units; a non-inline definition would be an ODR violation.
+inline cpp::optional<char *> strdup(const char *src) {
   if (src == nullptr)
-    return nullptr;
+    return cpp::nullopt;
   size_t len = string_length(src) + 1;
-  char *newstr = reinterpret_cast<char *>(::malloc(len));
-  if (newstr == nullptr)
-    return nullptr;
+  AllocChecker ac;
+  char *newstr = new (ac) char[len];
+  if (!ac)
+    return cpp::nullopt;
   inline_memcpy(newstr, src, len);
   return newstr;
 }
diff --git a/libc/src/string/strdup.cpp b/libc/src/string/strdup.cpp
--- a/libc/src/string/strdup.cpp
+++ b/libc/src/string/strdup.cpp
@@ -12,12 +12,18 @@
 
 #include "src/__support/common.h"
 
+#include <errno.h>
 #include <stdlib.h>
 
 namespace __llvm_libc {
 
 LLVM_LIBC_FUNCTION(char *, strdup, (const char *src)) {
-  return internal::strdup(src);
+  auto dup = internal::strdup(src);
+  if (dup)
+    return *dup;
+  if (src != nullptr)
+    errno = ENOMEM;
+  return nullptr;
 }
 
 } // namespace __llvm_libc
diff --git a/libc/src/unistd/linux/getcwd.cpp b/libc/src/unistd/linux/getcwd.cpp
--- a/libc/src/unistd/linux/getcwd.cpp
+++ b/libc/src/unistd/linux/getcwd.cpp
@@ -44,12 +44,12 @@
     char pathbuf[PATH_MAX];
     if (!getcwd_syscall(pathbuf, PATH_MAX))
       return nullptr;
-    char *cwd = internal::strdup(pathbuf);
-    if (cwd == nullptr) {
+    auto cwd = internal::strdup(pathbuf);
+    if (!cwd) {
       errno = ENOMEM;
       return nullptr;
     }
-    return cwd;
+    return *cwd;
   } else if (size == 0) {
     errno = EINVAL;
     return nullptr;
diff --git a/libc/test/src/string/CMakeLists.txt b/libc/test/src/string/CMakeLists.txt
--- a/libc/test/src/string/CMakeLists.txt
+++ b/libc/test/src/string/CMakeLists.txt
@@ -141,8 +141,10 @@
   SRCS
     strdup_test.cpp
   DEPENDS
+    libc.include.errno
     libc.include.stdlib
     libc.src.string.strdup
+    libc.src.errno.errno
 )
 
 add_libc_unittest(
diff --git a/libc/test/src/string/strdup_test.cpp b/libc/test/src/string/strdup_test.cpp
--- a/libc/test/src/string/strdup_test.cpp
+++ b/libc/test/src/string/strdup_test.cpp
@@ -8,12 +8,17 @@
 
 #include "src/string/strdup.h"
 #include "utils/UnitTest/Test.h"
+
+#include <errno.h>
 #include <stdlib.h>
 
 TEST(LlvmLibcStrDupTest, EmptyString) {
   const char *empty = "";
 
+  errno = 0;
   char *result = __llvm_libc::strdup(empty);
+  ASSERT_EQ(errno, 0);
+
   ASSERT_NE(result, static_cast<char *>(nullptr));
   ASSERT_NE(empty, const_cast<const char *>(result));
   ASSERT_STREQ(empty, result);
@@ -23,7 +28,9 @@
 TEST(LlvmLibcStrDupTest, AnyString) {
   const char *abc = "abc";
 
+  errno = 0;
   char *result = __llvm_libc::strdup(abc);
+  ASSERT_EQ(errno, 0);
 
   ASSERT_NE(result, static_cast<char *>(nullptr));
   ASSERT_NE(abc, const_cast<const char *>(result));
@@ -32,8 +39,9 @@
 }
 
 TEST(LlvmLibcStrDupTest, NullPtr) {
-
+  errno = 0;
   char *result = __llvm_libc::strdup(nullptr);
+  ASSERT_EQ(errno, 0);
   ASSERT_EQ(result, static_cast<char *>(nullptr));
 }
 