diff --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h --- a/libc/src/__support/RPC/rpc.h +++ b/libc/src/__support/RPC/rpc.h @@ -36,6 +36,7 @@ PRINT_TO_STDERR = 1, EXIT = 2, TEST_INCREMENT = 3, + TEST_INTERFACE = 4, }; /// A fixed size channel used to communicate between the RPC client and server. @@ -253,7 +254,8 @@ // TODO: This should be move-only. LIBC_INLINE Port(Process &process, uint64_t lane_mask, uint64_t index, uint32_t out) - : process(process), lane_mask(lane_mask), index(index), out(out) {} + : process(process), lane_mask(lane_mask), index(index), out(out), + recieve(false) {} LIBC_INLINE Port(const Port &) = delete; LIBC_INLINE Port &operator=(const Port &) = delete; LIBC_INLINE Port(Port &&) = default; @@ -272,13 +274,20 @@ return process.get_packet(index).header.opcode; } - LIBC_INLINE void close() { process.unlock(lane_mask, index); } + LIBC_INLINE void close() { + // If the server last did a receive it needs to exchange ownership before + // closing the port. + if (recieve && T) + out = process.invert_outbox(index, out); + process.unlock(lane_mask, index); + } private: Process &process; uint64_t lane_mask; uint64_t index; uint32_t out; + bool recieve; }; /// The RPC client used to make requests to the server. @@ -319,10 +328,16 @@ process.invoke_rpc(fill, process.get_packet(index)); atomic_thread_fence(cpp::MemoryOrder::RELEASE); out = process.invert_outbox(index, out); + recieve = false; } /// Applies \p use to the shared buffer and acknowledges the send. template template LIBC_INLINE void Port::recv(U use) { + // We only exchange ownership of the buffer during a receive if we are waiting + // for a previous receive to finish. + if (recieve) + out = process.invert_outbox(index, out); + uint32_t in = process.load_inbox(index); // We need to wait until we own the buffer before receiving. @@ -334,7 +349,7 @@ // Apply the \p use function to read the memory out of the buffer. 
process.invoke_rpc(use, process.get_packet(index)); - out = process.invert_outbox(index, out); + recieve = true; } /// Combines a send and receive into a single function. diff --git a/libc/test/integration/startup/gpu/CMakeLists.txt b/libc/test/integration/startup/gpu/CMakeLists.txt --- a/libc/test/integration/startup/gpu/CMakeLists.txt +++ b/libc/test/integration/startup/gpu/CMakeLists.txt @@ -36,3 +36,10 @@ SRCS init_fini_array_test.cpp ) + +add_integration_test( + startup_rpc_interface_test + SUITE libc-startup-tests + SRCS + rpc_interface_test.cpp +) diff --git a/libc/test/integration/startup/gpu/rpc_interface_test.cpp b/libc/test/integration/startup/gpu/rpc_interface_test.cpp new file mode 100644 --- /dev/null +++ b/libc/test/integration/startup/gpu/rpc_interface_test.cpp @@ -0,0 +1,43 @@ +//===-- Loader test to check the RPC interface with the loader ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/GPU/utils.h" +#include "src/__support/RPC/rpc_client.h" +#include "test/IntegrationTest/test.h" + +using namespace __llvm_libc; + +// Test to ensure that we can use arbitrary combinations of sends and receives +// as long as they are mirrored. 
+static void test_interface(bool end_with_send) { + uint64_t cnt = 0; + rpc::Client::Port port = rpc::client.open(); + port.send([&](rpc::Buffer *buffer) { buffer->data[0] = end_with_send; }); /* first packet carries the ending mode, not the counter */ + port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + if (end_with_send) + port.send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + else + port.recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port.close(); + + ASSERT_TRUE(cnt == 9 && "Invalid number of increments"); /* 4 client + 4 server counter sends, plus one final send from whichever side ends with a send; assumes the peer runs the mirrored TEST_INTERFACE handler */ +} + +TEST_MAIN(int argc, char **argv, char **envp) { + test_interface(true); + test_interface(false); + + return 0; +} diff --git a/libc/utils/gpu/loader/Server.h b/libc/utils/gpu/loader/Server.h --- a/libc/utils/gpu/loader/Server.h +++ b/libc/utils/gpu/loader/Server.h @@ -60,6 +60,41 @@ }); break; } + case __llvm_libc::rpc::Opcode::TEST_INTERFACE: { + uint64_t cnt = 0; + bool end_with_recv; /* assigned by the first recv below; presumably the client's end_with_send flag */ + port->recv([&](__llvm_libc::rpc::Buffer *buffer) { + end_with_recv = buffer->data[0]; + }); + port->recv( + [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port->send([&](__llvm_libc::rpc::Buffer *buffer) { + buffer->data[0] = cnt = cnt + 1; + }); + port->recv( + [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port->send([&](__llvm_libc::rpc::Buffer *buffer) { + buffer->data[0] = cnt = cnt + 1; + }); + port->recv( + [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port->recv( + [&](__llvm_libc::rpc::Buffer *buffer) { cnt = 
buffer->data[0]; }); + port->send([&](__llvm_libc::rpc::Buffer *buffer) { + buffer->data[0] = cnt = cnt + 1; + }); + port->send([&](__llvm_libc::rpc::Buffer *buffer) { + buffer->data[0] = cnt = cnt + 1; + }); + if (end_with_recv) + port->recv( + [&](__llvm_libc::rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + else + port->send([&](__llvm_libc::rpc::Buffer *buffer) { + buffer->data[0] = cnt = cnt + 1; + }); + break; + } default: port->recv([](__llvm_libc::rpc::Buffer *buffer) {}); }