diff --git a/libc/src/__support/CPP/type_traits.h b/libc/src/__support/CPP/type_traits.h --- a/libc/src/__support/CPP/type_traits.h +++ b/libc/src/__support/CPP/type_traits.h @@ -57,6 +57,11 @@ template using remove_reference_t = typename remove_reference::type; +template struct remove_pointer : type_identity {}; +template struct remove_pointer : type_identity {}; +template struct remove_pointer : type_identity {}; +template using remove_pointer_t = typename remove_pointer::type; + template struct add_rvalue_reference : type_identity {}; template struct remove_cvref { diff --git a/libc/src/__support/OSUtil/gpu/quick_exit.cpp b/libc/src/__support/OSUtil/gpu/quick_exit.cpp --- a/libc/src/__support/OSUtil/gpu/quick_exit.cpp +++ b/libc/src/__support/OSUtil/gpu/quick_exit.cpp @@ -18,9 +18,7 @@ void quick_exit(int status) { rpc::Client::Port port = rpc::client.open(); - port.send([&](rpc::Buffer *buffer) { - reinterpret_cast(buffer->data)[0] = status; - }); + port.send([=](int *buffer) { *buffer = status; }); port.close(); #if defined(LIBC_TARGET_ARCH_IS_NVPTX) diff --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h --- a/libc/src/__support/RPC/rpc.h +++ b/libc/src/__support/RPC/rpc.h @@ -241,26 +241,34 @@ } /// Invokes a function accross every active buffer across the total lane size. - LIBC_INLINE void invoke_rpc(cpp::function fn, - Packet &packet) { + template + LIBC_INLINE void invoke_rpc(cpp::function fn, Packet &packet) { + static_assert(cpp::is_pointer_v, "Argument must be a pointer"); + static_assert(sizeof(cpp::remove_pointer_t) <= sizeof(Buffer::data), + "Argument does not fit inside the packet"); if constexpr (is_process_gpu()) { - fn(&packet.payload.slot[gpu::get_lane_id()]); + fn(reinterpret_cast(packet.payload.slot[gpu::get_lane_id()].data)); } else { for (uint32_t i = 0; i < lane_size; i += gpu::get_lane_size()) if (packet.header.mask & 1ul << i) - fn(&packet.payload.slot[i]); + fn(reinterpret_cast(packet.payload.slot[i].data)); } } /// Alternate version that also provides the index of the current lane. - LIBC_INLINE void invoke_rpc(cpp::function fn, + template + LIBC_INLINE void invoke_rpc(cpp::function fn, Packet &packet) { + static_assert(cpp::is_pointer_v, "Argument must be a pointer"); + static_assert(sizeof(cpp::remove_pointer_t) <= sizeof(Buffer::data), + "Argument does not fit inside the packet"); if constexpr (is_process_gpu()) { - fn(&packet.payload.slot[gpu::get_lane_id()], gpu::get_lane_id()); + fn(reinterpret_cast(packet.payload.slot[gpu::get_lane_id()].data), + gpu::get_lane_id()); } else { for (uint32_t i = 0; i < lane_size; i += gpu::get_lane_size()) if (packet.header.mask & 1ul << i) - fn(&packet.payload.slot[i], i); + fn(reinterpret_cast(packet.payload.slot[i].data), i); } } @@ -314,8 +322,8 @@ friend class cpp::optional>; public: - template LIBC_INLINE void recv(U use); - template LIBC_INLINE void send(F fill); + template LIBC_INLINE void recv(U &&use); + template LIBC_INLINE void send(F &&fill); template LIBC_INLINE void send_and_recv(F fill, U use); template LIBC_INLINE void recv_and_send(W work); @@ -369,7 +377,9 @@ }; /// Applies \p fill to the shared buffer and initiates a send operation. -template template LIBC_INLINE void Port::send(F fill) { +template +template +LIBC_INLINE void Port::send(F &&fill) { uint32_t in = process.load_inbox(index); // We need to wait until we own the buffer before sending. @@ -379,14 +389,16 @@ } // Apply the \p fill function to initialize the buffer and release the memory. - process.invoke_rpc(fill, process.get_packet(index)); + process.template invoke_rpc>(fill, process.get_packet(index)); atomic_thread_fence(cpp::MemoryOrder::RELEASE); out = process.invert_outbox(index, out); receive = false; } /// Applies \p use to the shared buffer and acknowledges the send. -template template LIBC_INLINE void Port::recv(U use) { +template +template +LIBC_INLINE void Port::recv(U &&use) { // We only exchange ownership of the buffer during a receive if we are waiting // for a previous receive to finish. if (receive) @@ -402,7 +414,7 @@ atomic_thread_fence(cpp::MemoryOrder::ACQUIRE); // Apply the \p use function to read the memory out of the buffer. - process.invoke_rpc(use, process.get_packet(index)); + process.template invoke_rpc>(use, process.get_packet(index)); receive = true; } diff --git a/libc/src/__support/RPC/rpc_util.h b/libc/src/__support/RPC/rpc_util.h --- a/libc/src/__support/RPC/rpc_util.h +++ b/libc/src/__support/RPC/rpc_util.h @@ -9,6 +9,7 @@ #ifndef LLVM_LIBC_SRC_SUPPORT_RPC_RPC_UTILS_H #define LLVM_LIBC_SRC_SUPPORT_RPC_RPC_UTILS_H +#include "src/__support/CPP/type_traits.h" #include "src/__support/GPU/utils.h" #include "src/__support/macros/attributes.h" #include "src/__support/macros/properties/architectures.h" @@ -69,6 +70,20 @@ return x < y ? y : x; } +namespace internal { +template +Arg first_arg_helper(Ret (F::*)(Arg, Args...)); + +template +Arg first_arg_helper(Ret (F::*)(Arg, Args...) const); +} // namespace internal + +template struct first_arg { + using type = decltype(internal::first_arg_helper( + &cpp::remove_reference::type::operator())); +}; +template using first_arg_t = typename first_arg::type; + } // namespace rpc } // namespace __llvm_libc diff --git a/libc/test/integration/startup/gpu/rpc_test.cpp b/libc/test/integration/startup/gpu/rpc_test.cpp --- a/libc/test/integration/startup/gpu/rpc_test.cpp +++ b/libc/test/integration/startup/gpu/rpc_test.cpp @@ -18,13 +18,8 @@ uint64_t cnt = 0; for (uint32_t i = 0; i < num_additions; ++i) { rpc::Client::Port port = rpc::client.open(); - port.send_and_recv( - [=](rpc::Buffer *buffer) { - reinterpret_cast(buffer->data)[0] = cnt; - }, - [&](rpc::Buffer *buffer) { - cnt = reinterpret_cast(buffer->data)[0]; - }); + port.send_and_recv([=](uint64_t *buffer) { *buffer = cnt; }, + [&](uint64_t *buffer) { cnt = *buffer; }); port.close(); } ASSERT_TRUE(cnt == num_additions && "Incorrect sum"); @@ -33,7 +28,7 @@ // Test to ensure that the RPC mechanism doesn't hang on divergence. static void test_noop(uint8_t data) { rpc::Client::Port port = rpc::client.open(); - port.send([=](rpc::Buffer *buffer) { buffer->data[0] = data; }); + port.send([=](uint8_t *buffer) { *buffer = data; }); port.close(); } diff --git a/libc/utils/gpu/loader/Server.h b/libc/utils/gpu/loader/Server.h --- a/libc/utils/gpu/loader/Server.h +++ b/libc/utils/gpu/loader/Server.h @@ -44,34 +44,29 @@ break; } case rpc::Opcode::EXIT: { - port->recv([](rpc::Buffer *buffer) { - exit(reinterpret_cast(buffer->data)[0]); - }); + port->recv([](int *status) { exit(*status); }); break; } case rpc::Opcode::TEST_INCREMENT: { - port->recv_and_send([](rpc::Buffer *buffer) { - reinterpret_cast(buffer->data)[0] += 1; - }); + port->recv_and_send([](uint64_t *cnt) { *cnt += 1; }); break; } case rpc::Opcode::TEST_INTERFACE: { uint64_t cnt = 0; bool end_with_recv; - port->recv([&](rpc::Buffer *buffer) { end_with_recv = buffer->data[0]; }); - port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); - port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); - port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); - port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); - port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); - port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); - port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); - port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + port->recv([&](uint64_t *buffer) { end_with_recv = *buffer; }); + port->recv([&](uint64_t *buffer) { cnt = *buffer; }); + port->send([&](uint64_t *buffer) { *buffer = cnt = cnt + 1; }); + port->recv([&](uint64_t *buffer) { cnt = *buffer; }); + port->send([&](uint64_t *buffer) { *buffer = cnt = cnt + 1; }); + port->recv([&](uint64_t *buffer) { cnt = *buffer; }); + port->recv([&](uint64_t *buffer) { cnt = *buffer; }); + port->send([&](uint64_t *buffer) { *buffer = cnt = cnt + 1; }); + port->send([&](uint64_t *buffer) { *buffer = cnt = cnt + 1; }); if (end_with_recv) - port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; }); + port->recv([&](uint64_t *buffer) { cnt = *buffer; }); else - port->send( - [&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; }); + port->send([&](uint64_t *buffer) { *buffer = cnt = cnt + 1; }); break; } case rpc::Opcode::TEST_STREAM: {