diff --git a/libc/utils/gpu/server/rpc_server.h b/libc/utils/gpu/server/rpc_server.h
--- a/libc/utils/gpu/server/rpc_server.h
+++ b/libc/utils/gpu/server/rpc_server.h
@@ -97,6 +97,12 @@
 /// Returns the size of the client in bytes to be used for a memory copy.
 uint64_t rpc_get_client_size();
 
+/// Use the \p port to send a buffer using the \p callback.
+void rpc_send(rpc_port_t port, rpc_port_callback_ty callback, void *data);
+
+/// Use the \p port to receive a buffer using the \p callback.
+void rpc_recv(rpc_port_t port, rpc_port_callback_ty callback, void *data);
+
 /// Use the \p port to receive and send a buffer using the \p callback.
 void rpc_recv_and_send(rpc_port_t port, rpc_port_callback_ty callback,
                        void *data);
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -320,25 +320,50 @@
 
 uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
 
+/// A pointer to a server port of each lane size handled here (1, 32, or 64).
+using ServerPort = std::variant<rpc::Server<1>::Port *, rpc::Server<32>::Port *,
+                                rpc::Server<64>::Port *>;
+
+/// Reinterprets the opaque handle in \p ref as a port of the matching lane
+/// size. The lane size must be 1, 32, or 64: any other value reaches
+/// __builtin_unreachable() and is undefined behavior.
+static ServerPort getPort(rpc_port_t ref) {
+  if (ref.lane_size == 1)
+    return reinterpret_cast<rpc::Server<1>::Port *>(ref.handle);
+  if (ref.lane_size == 32)
+    return reinterpret_cast<rpc::Server<32>::Port *>(ref.handle);
+  if (ref.lane_size == 64)
+    return reinterpret_cast<rpc::Server<64>::Port *>(ref.handle);
+  __builtin_unreachable();
+}
+
+void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
+  std::visit(
+      [=](auto *port) {
+        port->send([=](rpc::Buffer *buffer) {
+          callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+        });
+      },
+      getPort(ref));
+}
+
+void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
+  std::visit(
+      [=](auto *port) {
+        port->recv([=](rpc::Buffer *buffer) {
+          callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+        });
+      },
+      getPort(ref));
+}
+
 void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
                        void *data) {
-  if (ref.lane_size == 1) {
-    rpc::Server<1>::Port *port =
-        reinterpret_cast<rpc::Server<1>::Port *>(ref.handle);
-    port->recv_and_send([=](rpc::Buffer *buffer) {
-      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-    });
-  } else if (ref.lane_size == 32) {
-    rpc::Server<32>::Port *port =
-        reinterpret_cast<rpc::Server<32>::Port *>(ref.handle);
-    port->recv_and_send([=](rpc::Buffer *buffer) {
-      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-    });
-  } else if (ref.lane_size == 64) {
-    rpc::Server<64>::Port *port =
-        reinterpret_cast<rpc::Server<64>::Port *>(ref.handle);
-    port->recv_and_send([=](rpc::Buffer *buffer) {
-      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-    });
-  }
+  std::visit(
+      [=](auto *port) {
+        port->recv_and_send([=](rpc::Buffer *buffer) {
+          callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+        });
+      },
+      getPort(ref));
 }