diff --git a/libc/config/linux/api.td b/libc/config/linux/api.td
--- a/libc/config/linux/api.td
+++ b/libc/config/linux/api.td
@@ -301,6 +301,20 @@
   ];
 }
 
+def CndT : TypeDecl<"cnd_t"> {
+  // Using a 4-byte field for raw futex data
+  // instead of a mtx_t field allows using
+  // FUTEX_WAKE_OP. See the implementation of
+  // cnd_signal for details.
+  let Decl = [{
+    typedef struct {
+      unsigned char __queue_status[4];
+      void *__waiter_queue_begin;
+      void *__waiter_queue_end;
+    } cnd_t;
+  }];
+}
+
 def MtxT : TypeDecl<"mtx_t"> {
   let Decl = [{
     typedef struct {
@@ -316,6 +330,7 @@
 
 def ThreadsAPI : PublicAPI<"threads.h"> {
   let TypeDeclarations = [
+    CndT,
     MtxT,
     ThreadStartT,
   ];
@@ -332,6 +347,11 @@
   ];
 
   let Functions = [
+    "cnd_broadcast",
+    "cnd_destroy",
+    "cnd_init",
+    "cnd_signal",
+    "cnd_wait",
     "mtx_init",
     "mtx_lock",
     "mtx_unlock",
diff --git a/libc/lib/CMakeLists.txt b/libc/lib/CMakeLists.txt
--- a/libc/lib/CMakeLists.txt
+++ b/libc/lib/CMakeLists.txt
@@ -35,6 +35,11 @@
     libc.src.sys.mman.munmap
 
     # threads.h entrypoints
+    libc.src.threads.cnd_broadcast
+    libc.src.threads.cnd_destroy
+    libc.src.threads.cnd_init
+    libc.src.threads.cnd_signal
+    libc.src.threads.cnd_wait
     libc.src.threads.mtx_init
     libc.src.threads.mtx_lock
     libc.src.threads.mtx_unlock
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -8,6 +8,8 @@
   RestrictedPtrType CharRestrictedPtr = RestrictedPtrType<CharType>;
   ConstType ConstCharRestrictedPtr = ConstType<CharRestrictedPtr>;
 
+  NamedType CndTType = NamedType<"cnd_t">;
+  PtrType CndTTypePtr = PtrType<CndTType>;
   NamedType MtxTType = NamedType<"mtx_t">;
   PtrType MtxTTypePtr = PtrType<MtxTType>;
   NamedType ThrdStartTType = NamedType<"thrd_start_t">;
@@ -269,6 +271,7 @@
       "threads.h",
       [], // Macros
       [
+        CndTType,
         MtxTType,
         ThrdStartTType,
         ThrdTType,
@@ -284,6 +287,42 @@
         EnumeratedNameValue<"thrd_nomem">,
       ],
       [
+        FunctionSpec<
+            "cnd_broadcast",
+            RetValSpec<IntType>,
+            [
+                ArgSpec<CndTTypePtr>,
+            ]
+        >,
+        FunctionSpec<
+            "cnd_destroy",
+            RetValSpec<VoidType>,
+            [
+                ArgSpec<CndTTypePtr>,
+            ]
+        >,
+        FunctionSpec<
+            "cnd_init",
+            RetValSpec<IntType>,
+            [
+                ArgSpec<CndTTypePtr>,
+            ]
+        >,
+        FunctionSpec<
+            "cnd_signal",
+            RetValSpec<IntType>,
+            [
+                ArgSpec<CndTTypePtr>,
+            ]
+        >,
+        FunctionSpec<
+            "cnd_wait",
+            RetValSpec<IntType>,
+            [
+                ArgSpec<CndTTypePtr>,
+                ArgSpec<MtxTTypePtr>,
+            ]
+        >,
         FunctionSpec<
             "mtx_init",
             RetValSpec<IntType>,
diff --git a/libc/src/threads/CMakeLists.txt b/libc/src/threads/CMakeLists.txt
--- a/libc/src/threads/CMakeLists.txt
+++ b/libc/src/threads/CMakeLists.txt
@@ -36,3 +36,38 @@
   DEPENDS
     .${LIBC_TARGET_OS}.mtx_unlock
 )
+
+add_entrypoint_object(
+  cnd_init
+  ALIAS
+  DEPENDS
+    .${LIBC_TARGET_OS}.cnd_init
+)
+
+add_entrypoint_object(
+  cnd_destroy
+  ALIAS
+  DEPENDS
+    .${LIBC_TARGET_OS}.cnd_destroy
+)
+
+add_entrypoint_object(
+  cnd_signal
+  ALIAS
+  DEPENDS
+    .${LIBC_TARGET_OS}.cnd_signal
+)
+
+add_entrypoint_object(
+  cnd_broadcast
+  ALIAS
+  DEPENDS
+    .${LIBC_TARGET_OS}.cnd_broadcast
+)
+
+add_entrypoint_object(
+  cnd_wait
+  ALIAS
+  DEPENDS
+    .${LIBC_TARGET_OS}.cnd_wait
+)
diff --git a/libc/src/threads/cnd_broadcast.h b/libc/src/threads/cnd_broadcast.h
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/cnd_broadcast.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for cnd_broadcast function --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_THREADS_CND_BROADCAST_H
+#define LLVM_LIBC_SRC_THREADS_CND_BROADCAST_H
+
+#include "include/threads.h"
+
+namespace __llvm_libc {
+
+int cnd_broadcast(cnd_t *cnd); // Wakes every thread waiting on |cnd|.
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_THREADS_CND_BROADCAST_H
diff --git a/libc/src/threads/cnd_destroy.h b/libc/src/threads/cnd_destroy.h
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/cnd_destroy.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for cnd_destroy function ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_THREADS_CND_DESTROY_H
+#define LLVM_LIBC_SRC_THREADS_CND_DESTROY_H
+
+#include "include/threads.h"
+
+namespace __llvm_libc {
+
+void cnd_destroy(cnd_t *cnd); // Destroys the condition variable |cnd|.
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_THREADS_CND_DESTROY_H
diff --git a/libc/src/threads/cnd_init.h b/libc/src/threads/cnd_init.h
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/cnd_init.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for cnd_init function -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_THREADS_CND_INIT_H
+#define LLVM_LIBC_SRC_THREADS_CND_INIT_H
+
+#include "include/threads.h"
+
+namespace __llvm_libc {
+
+int cnd_init(cnd_t *cnd); // Initializes |cnd| with no waiters.
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_THREADS_CND_INIT_H
diff --git a/libc/src/threads/cnd_signal.h b/libc/src/threads/cnd_signal.h
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/cnd_signal.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for cnd_signal function -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_THREADS_CND_SIGNAL_H
+#define LLVM_LIBC_SRC_THREADS_CND_SIGNAL_H
+
+#include "include/threads.h"
+
+namespace __llvm_libc {
+
+int cnd_signal(cnd_t *cnd); // Wakes one thread waiting on |cnd|, if any.
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_THREADS_CND_SIGNAL_H
diff --git a/libc/src/threads/cnd_wait.h b/libc/src/threads/cnd_wait.h
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/cnd_wait.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for cnd_wait function -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_THREADS_CND_WAIT_H
+#define LLVM_LIBC_SRC_THREADS_CND_WAIT_H
+
+#include "include/threads.h"
+
+namespace __llvm_libc {
+
+int cnd_wait(cnd_t *cnd, mtx_t *mtx); // Releases |mtx| and blocks on |cnd|.
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_THREADS_CND_WAIT_H
diff --git a/libc/src/threads/linux/CMakeLists.txt b/libc/src/threads/linux/CMakeLists.txt
--- a/libc/src/threads/linux/CMakeLists.txt
+++ b/libc/src/threads/linux/CMakeLists.txt
@@ -87,3 +87,65 @@
     libc.include.sys_syscall
     libc.include.threads
 )
+
+add_entrypoint_object(
+  cnd_init
+  SRCS
+    cnd_init.cpp
+  HDRS
+    ../cnd_init.h
+  DEPENDS
+    .threads_utils
+    libc.include.threads
+)
+
+add_entrypoint_object(
+  cnd_destroy
+  SRCS
+    cnd_destroy.cpp
+  HDRS
+    ../cnd_destroy.h
+  DEPENDS
+    libc.include.threads
+)
+
+add_entrypoint_object(
+  cnd_signal
+  SRCS
+    cnd_signal.cpp
+  HDRS
+    ../cnd_signal.h
+  DEPENDS
+    .threads_utils
+    libc.config.linux.linux_syscall_h
+    libc.include.sys_syscall
+    libc.include.threads
+)
+
+add_entrypoint_object(
+  cnd_broadcast
+  SRCS
+    cnd_broadcast.cpp
+  HDRS
+    ../cnd_broadcast.h
+  DEPENDS
+    .threads_utils
+    libc.config.linux.linux_syscall_h
+    libc.include.sys_syscall
+    libc.include.threads
+)
+
+add_entrypoint_object(
+  cnd_wait
+  SRCS
+    cnd_wait.cpp
+  HDRS
+    ../cnd_wait.h
+  DEPENDS
+    .threads_utils
+    libc.config.linux.linux_syscall_h
+    libc.include.sys_syscall
+    libc.include.threads
+    libc.src.threads.mtx_lock
+    libc.src.threads.mtx_unlock
+)
diff --git a/libc/src/threads/linux/cnd_broadcast.cpp b/libc/src/threads/linux/cnd_broadcast.cpp
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/linux/cnd_broadcast.cpp
@@ -0,0 +1,64 @@
+//===-- Linux implementation of the cnd_broadcast function ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "config/linux/syscall.h" // For syscall functions.
+#include "include/sys/syscall.h"  // For syscall numbers.
+#include "include/threads.h"      // For cnd_t and other definitions.
+#include "src/__support/common.h"
+#include "src/threads/linux/thread_utils.h" // For the status enums.
+
+#include <limits.h>      // For INT_MAX.
+#include <linux/futex.h> // For futex operations.
+#include <stdatomic.h>   // For atomic operations.
+
+namespace __llvm_libc {
+
+// Wakes up every thread that is waiting on |cnd|.
+//
+// The waiter queue is held busy for the whole operation. Waiters are woken in
+// FIFO order, each with a single FUTEX_WAKE_OP which sets its status to
+// CWS_WaitOver and wakes it. This way a waiter's stack-allocated status word
+// is never accessed after the waiter is allowed to return from cnd_wait.
+int LLVM_LIBC_ENTRYPOINT(cnd_broadcast)(cnd_t *cnd) {
+  auto *queue_status = reinterpret_cast<atomic_uint *>(cnd->__queue_status);
+
+  // Acquire exclusive access to the waiter queue.
+  while (true) {
+    // A failed compare-exchange overwrites |queue_free| with the value it
+    // observed, so it must be reset to CQS_Free before every attempt.
+    unsigned int queue_free = CQS_Free;
+    if (atomic_compare_exchange_strong(queue_status, &queue_free, CQS_Busy))
+      break;
+    // The waiter queue is busy so wait for it to be released.
+    __llvm_libc::syscall(SYS_futex, queue_status, FUTEX_WAIT_PRIVATE, CQS_Busy,
+                         0, 0, 0);
+  }
+
+  auto *waiter = reinterpret_cast<CndWaiter *>(cnd->__waiter_queue_begin);
+  while (waiter != nullptr) {
+    // Read the link before waking the waiter up: its storage may go away as
+    // soon as it is woken.
+    CndWaiter *next_waiter = waiter->next;
+    __llvm_libc::syscall(
+        SYS_futex, queue_status, FUTEX_WAKE_OP_PRIVATE,
+        0, // The waiter queue is still busy, so wake nothing on it.
+        1, // Wake this waiter.
+        &waiter->status,
+        FUTEX_OP(FUTEX_OP_SET, CWS_WaitOver, FUTEX_OP_CMP_EQ, CWS_Waiting));
+    waiter = next_waiter;
+  }
+  cnd->__waiter_queue_begin = cnd->__waiter_queue_end = nullptr;
+
+  // Unblock the waiter queue and wake up threads waiting to access it.
+  *queue_status = CQS_Free;
+  __llvm_libc::syscall(SYS_futex, queue_status, FUTEX_WAKE_PRIVATE, INT_MAX, 0,
+                       0, 0);
+  return thrd_success;
+}
+
+} // namespace __llvm_libc
diff --git a/libc/src/threads/linux/cnd_destroy.cpp b/libc/src/threads/linux/cnd_destroy.cpp
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/linux/cnd_destroy.cpp
@@ -0,0 +1,16 @@
+//===-- Linux implementation of the cnd_destroy function ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "include/threads.h" // For cnd_t definition.
+#include "src/__support/common.h"
+
+namespace __llvm_libc {
+
+void LLVM_LIBC_ENTRYPOINT(cnd_destroy)(cnd_t *) {}
+
+} // namespace __llvm_libc
diff --git a/libc/src/threads/linux/cnd_init.cpp b/libc/src/threads/linux/cnd_init.cpp
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/linux/cnd_init.cpp
@@ -0,0 +1,21 @@
+//===-- Linux implementation of the cnd_init function ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "include/threads.h" // For cnd_t and other definitions.
+#include "src/__support/common.h"
+#include "src/threads/linux/thread_utils.h" // For the status enums.
+
+namespace __llvm_libc {
+
+int LLVM_LIBC_ENTRYPOINT(cnd_init)(cnd_t *cnd) {
+  cnd->__waiter_queue_begin = cnd->__waiter_queue_end = nullptr;
+  *reinterpret_cast<atomic_uint *>(cnd->__queue_status) = CQS_Free;
+  return thrd_success;
+}
+
+} // namespace __llvm_libc
diff --git a/libc/src/threads/linux/cnd_signal.cpp b/libc/src/threads/linux/cnd_signal.cpp
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/linux/cnd_signal.cpp
@@ -0,0 +1,67 @@
+//===-- Linux implementation of the cnd_signal function -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "config/linux/syscall.h" // For syscall functions.
+#include "include/sys/syscall.h"  // For syscall numbers.
+#include "include/threads.h"      // For cnd_t and other definitions.
+#include "src/__support/common.h"
+#include "src/threads/linux/thread_utils.h" // For the status enums.
+
+#include <limits.h>      // For INT_MAX.
+#include <linux/futex.h> // For futex operations.
+#include <stdatomic.h>   // For atomic operations.
+
+namespace __llvm_libc {
+
+// Wakes up the thread that has been waiting on |cnd| the longest, if any.
+int LLVM_LIBC_ENTRYPOINT(cnd_signal)(cnd_t *cnd) {
+  auto *queue_status = reinterpret_cast<atomic_uint *>(cnd->__queue_status);
+
+  // Acquire exclusive access to the waiter queue.
+  while (true) {
+    // A failed compare-exchange overwrites |queue_free| with the value it
+    // observed, so it must be reset to CQS_Free before every attempt.
+    unsigned int queue_free = CQS_Free;
+    if (atomic_compare_exchange_strong(queue_status, &queue_free, CQS_Busy))
+      break;
+    // The waiter queue is busy so wait for it to be released.
+    __llvm_libc::syscall(SYS_futex, queue_status, FUTEX_WAIT_PRIVATE, CQS_Busy,
+                         0, 0, 0);
+  }
+
+  auto *first_waiter =
+      reinterpret_cast<CndWaiter *>(cnd->__waiter_queue_begin);
+  if (first_waiter == nullptr) {
+    // There are no waiters so unblock the waiter queue and return.
+    *queue_status = CQS_Free;
+    __llvm_libc::syscall(SYS_futex, queue_status, FUTEX_WAKE_PRIVATE, INT_MAX,
+                         0, 0, 0);
+    return thrd_success;
+  }
+
+  // Dequeue the first waiter.
+  cnd->__waiter_queue_begin = first_waiter->next;
+  auto *last_waiter = reinterpret_cast<CndWaiter *>(cnd->__waiter_queue_end);
+  if (last_waiter == first_waiter)
+    cnd->__waiter_queue_end = nullptr;
+
+  *queue_status = CQS_Free;
+  // Unblock the waiter queue and wake the first waiter both with one
+  // atomic futex operation: the kernel sets the waiter's status to
+  // CWS_WaitOver and wakes it, so this thread never touches the waiter's
+  // stack-allocated status word after the waiter is allowed to return.
+  __llvm_libc::syscall(
+      SYS_futex, queue_status, FUTEX_WAKE_OP_PRIVATE,
+      INT_MAX, // Unblock the waiter queue.
+      1,       // Wake the first waiter.
+      &first_waiter->status,
+      FUTEX_OP(FUTEX_OP_SET, CWS_WaitOver, FUTEX_OP_CMP_EQ, CWS_Waiting));
+  return thrd_success;
+}
+
+} // namespace __llvm_libc
diff --git a/libc/src/threads/linux/cnd_wait.cpp b/libc/src/threads/linux/cnd_wait.cpp
new file mode 100644
--- /dev/null
+++ b/libc/src/threads/linux/cnd_wait.cpp
@@ -0,0 +1,77 @@
+//===-- Linux implementation of the cnd_wait function ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "config/linux/syscall.h" // For syscall functions.
+#include "include/sys/syscall.h"  // For syscall numbers.
+#include "include/threads.h"      // For cnd_t definition.
+#include "src/__support/common.h"
+#include "src/threads/linux/thread_utils.h" // For the status enums.
+#include "src/threads/mtx_lock.h"
+#include "src/threads/mtx_unlock.h"
+
+#include <limits.h>      // For INT_MAX.
+#include <linux/futex.h> // For futex operations.
+#include <stdatomic.h>   // For atomic operations.
+
+namespace __llvm_libc {
+
+// Releases |mtx| and blocks until |cnd| is signalled or broadcast, then
+// re-acquires |mtx| and returns the result of locking it.
+int LLVM_LIBC_ENTRYPOINT(cnd_wait)(cnd_t *cnd, mtx_t *mtx) {
+  CndWaiter waiter;
+  auto *queue_status = reinterpret_cast<atomic_uint *>(cnd->__queue_status);
+
+  // Acquire exclusive access to the waiter queue.
+  while (true) {
+    // A failed compare-exchange overwrites |queue_free| with the value it
+    // observed, so it must be reset to CQS_Free before every attempt.
+    unsigned int queue_free = CQS_Free;
+    if (atomic_compare_exchange_strong(queue_status, &queue_free, CQS_Busy))
+      break;
+    // The waiter queue is busy so wait for it to be released.
+    __llvm_libc::syscall(SYS_futex, queue_status, FUTEX_WAIT_PRIVATE, CQS_Busy,
+                         0, 0, 0);
+  }
+
+  // Append this thread's waiter to the end of the queue.
+  if (cnd->__waiter_queue_begin == nullptr) {
+    cnd->__waiter_queue_begin = &waiter;
+  } else {
+    auto *last_waiter =
+        reinterpret_cast<CndWaiter *>(cnd->__waiter_queue_end);
+    last_waiter->next = &waiter;
+  }
+  cnd->__waiter_queue_end = &waiter;
+
+  // Unblock the waiter queue.
+  *queue_status = CQS_Free;
+  __llvm_libc::syscall(SYS_futex, queue_status, FUTEX_WAKE_PRIVATE, INT_MAX, 0,
+                       0, 0);
+
+  // Unlock the mutex and wait to be signalled. The waiter was enqueued while
+  // the mutex was still held, so a signal sent right after the unlock is not
+  // lost.
+  // NOTE(review): if the unlock fails, |waiter| is left in the queue and will
+  // dangle once this function returns. TODO: dequeue it before returning.
+  if (__llvm_libc::mtx_unlock(mtx) != thrd_success)
+    return thrd_error;
+
+  // FUTEX_WAIT returns immediately if the status is no longer CWS_Waiting, but
+  // it can also return early (for example with EINTR). Keep waiting until a
+  // signaller has actually set the status to CWS_WaitOver: returning before
+  // that would leave |waiter| in the queue after this stack frame is gone.
+  auto *waiter_status = reinterpret_cast<atomic_uint *>(&waiter.status);
+  while (atomic_load(waiter_status) == CWS_Waiting) {
+    __llvm_libc::syscall(SYS_futex, waiter_status, FUTEX_WAIT_PRIVATE,
+                         CWS_Waiting, 0, 0, 0);
+  }
+
+  return __llvm_libc::mtx_lock(mtx);
+}
+
+} // namespace __llvm_libc
diff --git a/libc/src/threads/linux/thread_utils.h b/libc/src/threads/linux/thread_utils.h
--- a/libc/src/threads/linux/thread_utils.h
+++ b/libc/src/threads/linux/thread_utils.h
@@ -27,6 +27,15 @@
 // made only if the mutex status is `MutexStatus::Waiting`.
enum MutexStatus : uint32_t { MS_Free, MS_Locked, MS_Waiting }; +enum CndQueueStatus : unsigned int { CQS_Free, CQS_Busy }; + +enum CndWaiterStatus : unsigned int { CWS_WaitOver, CWS_Waiting }; + +struct CndWaiter { + CndWaiterStatus status = CWS_Waiting; + CndWaiter *next = nullptr; +}; + static_assert(sizeof(atomic_uint) == 4, "Size of the `atomic_uint` type is not 4 bytes on your platform. " "The implementation of the standard threads library for linux " diff --git a/libc/test/src/threads/CMakeLists.txt b/libc/test/src/threads/CMakeLists.txt --- a/libc/test/src/threads/CMakeLists.txt +++ b/libc/test/src/threads/CMakeLists.txt @@ -28,3 +28,23 @@ libc.src.threads.thrd_create libc.src.threads.thrd_join ) + +add_libc_unittest( + cnd_test + SUITE + libc_threads_unittests + SRCS + cnd_test.cpp + DEPENDS + libc.include.threads + libc.src.threads.cnd_broadcast + libc.src.threads.cnd_destroy + libc.src.threads.cnd_init + libc.src.threads.cnd_signal + libc.src.threads.cnd_wait + libc.src.threads.mtx_init + libc.src.threads.mtx_lock + libc.src.threads.mtx_unlock + libc.src.threads.thrd_create + libc.src.threads.thrd_join +) diff --git a/libc/test/src/threads/cnd_test.cpp b/libc/test/src/threads/cnd_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/src/threads/cnd_test.cpp @@ -0,0 +1,275 @@ +//===-- Unittests for cnd_t -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "include/threads.h" +#include "src/threads/cnd_broadcast.h" +#include "src/threads/cnd_destroy.h" +#include "src/threads/cnd_init.h" +#include "src/threads/cnd_signal.h" +#include "src/threads/cnd_wait.h" +#include "src/threads/mtx_init.h" +#include "src/threads/mtx_lock.h" +#include "src/threads/mtx_unlock.h" +#include "src/threads/thrd_create.h" +#include "src/threads/thrd_join.h" +#include "utils/UnitTest/Test.h" + +#include + +static constexpr int START = 1; +static constexpr int WAITING = 2; +static constexpr int SIGNALLED = 3; +static constexpr int DONE = 4; +static constexpr int FAIL = 5; + +static mtx_t mtx; +static int shared_status; +static cnd_t cnd; + +// This function will keep signalling until it sees that +// the waiter is waiting. +static int simple_sig_func(void *) { + __llvm_libc::mtx_lock(&mtx); + if (shared_status != WAITING) { + // Wait for the waiter to signal. + __llvm_libc::cnd_wait(&cnd, &mtx); + // The expectation is that the waiter will only signal after + // changing the status to waiting. If not, it is an error. + if (shared_status != WAITING) { + __llvm_libc::mtx_unlock(&mtx); + return FAIL; + } + } + + shared_status = SIGNALLED; + __llvm_libc::mtx_unlock(&mtx); + __llvm_libc::cnd_signal(&cnd); + + return 0; + + while (true) { + __llvm_libc::mtx_lock(&mtx); + if (shared_status != DONE) { + shared_status = SIGNALLED; + __llvm_libc::mtx_unlock(&mtx); + // There could be no waiters yet at this point + // but that is OK. A waiter will appear eventually + // and by signalling in a loop, we wake them up. 
+ __llvm_libc::cnd_signal(&cnd); + continue; + } + __llvm_libc::mtx_unlock(&mtx); + return 0; + } +} + +// This function will wait for a signal +static int simple_wait_func(void *) { + __llvm_libc::mtx_lock(&mtx); + shared_status = WAITING; + __llvm_libc::mtx_unlock(&mtx); + __llvm_libc::cnd_signal(&cnd); + + __llvm_libc::mtx_lock(&mtx); + if (shared_status != SIGNALLED) { + __llvm_libc::cnd_wait(&cnd, &mtx); + if (shared_status != SIGNALLED) { + // The signalling function should signal only after setting + // status to signalled. Else it is an error. + __llvm_libc::mtx_unlock(&mtx); + return FAIL; + } + } + + shared_status = DONE; + __llvm_libc::mtx_unlock(&mtx); + return 0; +} + +TEST(CndTest, SimpleWaitAndSignal) { + ASSERT_EQ(__llvm_libc::mtx_init(&mtx, mtx_plain), + static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::cnd_init(&cnd), static_cast(thrd_success)); + shared_status = START; + + thrd_t sig_thrd, wait_thrd; + ASSERT_EQ(__llvm_libc::thrd_create(&sig_thrd, simple_sig_func, nullptr), + static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::thrd_create(&wait_thrd, simple_wait_func, nullptr), + static_cast(thrd_success)); + + int retval; + ASSERT_EQ(__llvm_libc::thrd_join(&sig_thrd, &retval), + static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::thrd_join(&wait_thrd, &retval), + static_cast(thrd_success)); + + __llvm_libc::cnd_destroy(&cnd); + + ASSERT_EQ(shared_status, DONE); +} + +static constexpr int BROADCAST_WAITER_THREADS = 5; +static cnd_t broadcast_cnd, main_thread_signal_cnd; +static mtx_t main_thread_signal_mtx, broadcast_mtx; +static atomic_int count; + +static int broadcast_func(void *) { + // Acquire the broadcast mutex first to ensure that + // all waiters are indeed waiting. 
+ __llvm_libc::mtx_lock(&broadcast_mtx); + __llvm_libc::cnd_broadcast(&broadcast_cnd); + __llvm_libc::mtx_unlock(&broadcast_mtx); + return 0; +} + +static int waiter_func(void *) { + // Acquire the broadcast lock immediately to prevent the + // broadcaster from broadcasting before all threads are ready. + // This also serializes the waiters lining up to wait making + // it absolutely sure that the broadcaster will broadcast + // only after the last waiter starts waiting. + __llvm_libc::mtx_lock(&broadcast_mtx); + ++count; + if (count == BROADCAST_WAITER_THREADS) { + // Signal the main thread that all waiters are now waiting. + __llvm_libc::mtx_lock(&main_thread_signal_mtx); + __llvm_libc::cnd_signal(&main_thread_signal_cnd); + __llvm_libc::mtx_unlock(&main_thread_signal_mtx); + } + __llvm_libc::cnd_wait(&broadcast_cnd, &broadcast_mtx); + --count; + __llvm_libc::mtx_unlock(&broadcast_mtx); + return 0; +} + +TEST(CndTest, Broadcast) { + ASSERT_EQ(__llvm_libc::mtx_init(&broadcast_mtx, mtx_plain), + static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::mtx_init(&main_thread_signal_mtx, mtx_plain), + static_cast(thrd_success)); + + ASSERT_EQ(__llvm_libc::cnd_init(&cnd), static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::cnd_init(&main_thread_signal_cnd), + static_cast(thrd_success)); + count = 0; + + // Acquire the main thread signal lock before starting any threads. + // This ensures that the last waiter thread will only signal after the + // main thread starts waiting for a signal. + ASSERT_EQ(__llvm_libc::mtx_lock(&main_thread_signal_mtx), + static_cast(thrd_success)); + + thrd_t waiters[BROADCAST_WAITER_THREADS]; + for (int i = 0; i < BROADCAST_WAITER_THREADS; ++i) { + ASSERT_EQ(__llvm_libc::thrd_create(waiters + i, waiter_func, nullptr), + static_cast(thrd_success)); + } + + // All waiter threads have started so wait for the last one to update + // the count and signal. 
+ __llvm_libc::cnd_wait(&main_thread_signal_cnd, &main_thread_signal_mtx); + ASSERT_EQ((int)count, BROADCAST_WAITER_THREADS); + ASSERT_EQ(__llvm_libc::mtx_unlock(&main_thread_signal_mtx), + static_cast(thrd_success)); + + // Now that we know all waiters are waiting, start the broadcast thread. + thrd_t broadcast_thrd; + ASSERT_EQ(__llvm_libc::thrd_create(&broadcast_thrd, broadcast_func, nullptr), + static_cast(thrd_success)); + + int retval; + for (int i = 0; i < BROADCAST_WAITER_THREADS; ++i) { + ASSERT_EQ(__llvm_libc::thrd_join(waiters + i, &retval), + static_cast(thrd_success)); + } + + ASSERT_EQ(__llvm_libc::thrd_join(&broadcast_thrd, &retval), + static_cast(thrd_success)); + + ASSERT_EQ((int)count, 0); +} + +// A simple cond var class used to setup a suprious wake test. +class CondVar { + cnd_t cnd; + +public: + typedef bool predicate(void); + + CondVar() { __llvm_libc::cnd_init(&cnd); } + ~CondVar() { __llvm_libc::cnd_destroy(&cnd); } + + void wait(mtx_t *mtx, predicate pred) { + // Assumes it is safe to run the predicate. That is, assumes predicate + // access is serialized via the mutex and that it is currently locked + // making it safe to call the predicate. + while (!pred()) { + __llvm_libc::cnd_wait(&cnd, mtx); + } + } + + void notify_one() { __llvm_libc::cnd_signal(&cnd); } +}; + +static CondVar cond_var; +static mtx_t cond_var_mtx; +static int cond_var_counter; + +static constexpr int COND_VAR_START_COUNT = 0; +static constexpr int COND_VAR_FINISH_COUNT = -100; + +bool counter_test() { + // Wait for a large value to allow for a few spurious + // wakes. + return cond_var_counter >= 1000000; +} + +int cond_var_waiter(void *) { + __llvm_libc::mtx_lock(&cond_var_mtx); + // The cond var will be supriously woken up but the predicate test + // will ensure it keeps waiting until it is satisfied. 
+ cond_var.wait(&cond_var_mtx, counter_test); + cond_var_counter = COND_VAR_FINISH_COUNT; + __llvm_libc::mtx_unlock(&cond_var_mtx); + return 0; +} + +int cond_var_notifier(void *) { + while (true) { + __llvm_libc::mtx_lock(&cond_var_mtx); + if (cond_var_counter == COND_VAR_FINISH_COUNT) { + __llvm_libc::mtx_unlock(&cond_var_mtx); + return 0; + } + ++cond_var_counter; + __llvm_libc::mtx_unlock(&cond_var_mtx); + cond_var.notify_one(); + } +} + +TEST(CndTest, SpuriousWake) { + cond_var_counter = COND_VAR_START_COUNT; + + ASSERT_EQ(__llvm_libc::mtx_init(&cond_var_mtx, mtx_plain), + static_cast(thrd_success)); + + thrd_t waiter, notifier; + ASSERT_EQ(__llvm_libc::thrd_create(&waiter, cond_var_waiter, nullptr), + static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::thrd_create(¬ifier, cond_var_notifier, nullptr), + static_cast(thrd_success)); + + int retval; + ASSERT_EQ(__llvm_libc::thrd_join(&waiter, &retval), + static_cast(thrd_success)); + ASSERT_EQ(__llvm_libc::thrd_join(¬ifier, &retval), + static_cast(thrd_success)); + + ASSERT_EQ(cond_var_counter, COND_VAR_FINISH_COUNT); +}