diff --git a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp --- a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp @@ -520,9 +520,9 @@ } /// Wait until the signal gets a zero value. - Error wait(const uint64_t ActiveTimeout = 0, - RPCHandleTy *RPCHandle = nullptr) const { - if (ActiveTimeout && !RPCHandle) { + Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr, + GenericDeviceTy *Device = nullptr) const { + if (ActiveTimeout && !RPCServer) { hsa_signal_value_t Got = 1; Got = hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0, ActiveTimeout, HSA_WAIT_STATE_ACTIVE); @@ -531,12 +531,12 @@ } // If there is an RPC device attached to this stream we run it as a server. - uint64_t Timeout = RPCHandle ? 8192 : UINT64_MAX; - auto WaitState = RPCHandle ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED; + uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX; + auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED; while (hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0, Timeout, WaitState) != 0) { - if (RPCHandle) - if (auto Err = RPCHandle->runServer()) + if (RPCServer && Device) + if (auto Err = RPCServer->runServer(*Device)) return Err; } return Plugin::success(); @@ -888,6 +888,9 @@ /// The manager of signals to reuse signals. AMDGPUSignalManagerTy &SignalManager; + /// A reference to the associated device. + GenericDeviceTy &Device; + /// Array of stream slots. Use std::deque because it can dynamically grow /// without invalidating the already inserted elements. For instance, the /// std::vector may invalidate the elements by reallocating the internal @@ -907,7 +910,7 @@ /// A pointer associated with an RPC server running on the given device. If /// RPC is not being used this will be a null pointer. Otherwise, this /// indicates that an RPC server is expected to be run on this stream. - RPCHandleTy *RPCHandle; + RPCServerTy *RPCServer; /// Mutex to protect stream's management. mutable std::mutex Mutex; @@ -1064,8 +1067,8 @@ /// Deinitialize the stream's signals. Error deinit() { return Plugin::success(); } - /// Attach an RPC handle to this stream. - void setRPCHandle(RPCHandleTy *Handle) { RPCHandle = Handle; } + /// Attach an RPC server to this stream. + void setRPCServer(RPCServerTy *Server) { RPCServer = Server; } /// Push a asynchronous kernel to the stream. The kernel arguments must be /// placed in a special allocation for kernel args and must keep alive until @@ -1281,8 +1284,8 @@ return Plugin::success(); // Wait until all previous operations on the stream have completed. - if (auto Err = - Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, RPCHandle)) + if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, + RPCServer, &Device)) return Err; // Reset the stream and perform all pending post actions. @@ -2529,9 +2532,9 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device) : Agent(Device.getAgent()), Queue(Device.getNextQueue()), - SignalManager(Device.getSignalManager()), + SignalManager(Device.getSignalManager()), Device(Device), // Initialize the std::deque with some empty positions. - Slots(32), NextSlot(0), SyncCycle(0), RPCHandle(nullptr), + Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr), StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()) {} /// Class implementing the AMDGPU-specific functionalities of the global @@ -2866,8 +2869,8 @@ AMDGPUStreamTy &Stream = AMDGPUDevice.getStream(AsyncInfoWrapper); // If this kernel requires an RPC server we attach its pointer to the stream. - if (GenericDevice.getRPCHandle()) - Stream.setRPCHandle(GenericDevice.getRPCHandle()); + if (GenericDevice.getRPCServer()) + Stream.setRPCServer(GenericDevice.getRPCServer()); // Push the kernel launch into the stream. return Stream.pushKernelLaunch(*this, AllArgs, NumThreads, NumBlocks, diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt @@ -70,7 +70,6 @@ find_library(llvmlibc_rpc_server NAMES llvmlibc_rpc_server PATHS ${LIBOMPTARGET_LLVM_LIBRARY_DIR} NO_DEFAULT_PATH) if(llvmlibc_rpc_server) - message(WARNING ${llvmlibc_rpc_server}) target_link_libraries(PluginInterface PRIVATE llvmlibc_rpc_server) target_compile_definitions(PluginInterface PRIVATE LIBOMPTARGET_RPC_SUPPORT) endif() diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h @@ -762,7 +762,7 @@ } /// Get the RPC server running on this device. - RPCHandleTy *getRPCHandle() const { return RPCHandle; } + RPCServerTy *getRPCServer() const { return RPCServer; } private: /// Register offload entry for global variable. @@ -857,7 +857,7 @@ /// A pointer to an RPC server instance attached to this device if present. /// This is used to run the RPC server during task synchronization. - RPCHandleTy *RPCHandle; + RPCServerTy *RPCServer; #ifdef OMPT_SUPPORT /// OMPT callback functions diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp @@ -401,7 +401,7 @@ OMPX_InitialNumEvents("LIBOMPTARGET_NUM_INITIAL_EVENTS", 32), DeviceId(DeviceId), GridValues(OMPGridValues), PeerAccesses(NumDevices, PeerAccessState::PENDING), PeerAccessesLock(), - PinnedAllocs(*this), RPCHandle(nullptr) { + PinnedAllocs(*this), RPCServer(nullptr) { #ifdef OMPT_SUPPORT OmptInitialized.store(false); // Bind the callbacks to this device's member functions @@ -483,8 +483,8 @@ if (RecordReplay.isRecordingOrReplaying()) RecordReplay.deinit(); - if (RPCHandle) - if (auto Err = RPCHandle->deinitDevice()) + if (RPCServer) + if (auto Err = RPCServer->deinitDevice(*this)) return Err; #ifdef OMPT_SUPPORT @@ -599,10 +599,7 @@ if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image)) return Err; - auto DeviceOrErr = Server.getDevice(*this); - if (!DeviceOrErr) - return DeviceOrErr.takeError(); - RPCHandle = *DeviceOrErr; + RPCServer = &Server; DP("Running an RPC server on device %d\n", getDeviceId()); return Plugin::success(); } diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h @@ -32,21 +32,6 @@ /// these routines will perform no action. struct RPCServerTy { public: - /// A wrapper around a single instance of the RPC server for a given device. - /// This is provided to simplify ownership of the underlying device. - struct RPCHandleTy { - RPCHandleTy(RPCServerTy &Server, plugin::GenericDeviceTy &Device) - : Server(Server), Device(Device) {} - - llvm::Error runServer() { return Server.runServer(Device); } - - llvm::Error deinitDevice() { return Server.deinitDevice(Device); } - - private: - RPCServerTy &Server; - plugin::GenericDeviceTy &Device; - }; - RPCServerTy(uint32_t NumDevices); /// Check if this device image is using an RPC server. This checks for the @@ -63,9 +48,6 @@ plugin::GenericGlobalHandlerTy &Handler, plugin::DeviceImageTy &Image); - /// Gets a reference to this server for a specific device. - llvm::Expected getDevice(plugin::GenericDeviceTy &Device); - /// Runs the RPC server associated with the \p Device until the pending work /// is cleared. llvm::Error runServer(plugin::GenericDeviceTy &Device); @@ -75,13 +57,8 @@ llvm::Error deinitDevice(plugin::GenericDeviceTy &Device); ~RPCServerTy(); - -private: - llvm::SmallVector> Handles; }; -using RPCHandleTy = RPCServerTy::RPCHandleTy; - } // namespace llvm::omp::target #endif diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp @@ -28,7 +28,6 @@ // If this fails then something is catastrophically wrong, just exit. if (rpc_status_t Err = rpc_init(NumDevices)) FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err); - Handles.resize(NumDevices); #endif } @@ -118,28 +117,10 @@ if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer, rpc_get_client_size(), nullptr)) return Err; - - Handles[DeviceId] = std::make_unique(*this, Device); #endif return Error::success(); } -llvm::Expected -RPCServerTy::getDevice(plugin::GenericDeviceTy &Device) { -#ifdef LIBOMPTARGET_RPC_SUPPORT - uint32_t DeviceId = Device.getDeviceId(); - if (!Handles[DeviceId] || !rpc_get_buffer(DeviceId) || - !rpc_get_client_buffer(DeviceId)) - return plugin::Plugin::error( - "Attempt to get an RPC device while not initialized"); - - return Handles[DeviceId].get(); -#else - return plugin::Plugin::error( - "Attempt to get an RPC device while not available"); -#endif -} - Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) { #ifdef LIBOMPTARGET_RPC_SUPPORT if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId())) diff --git a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp @@ -474,12 +474,12 @@ CUresult Res; // If we have an RPC server running on this device we will continuously // query it for work rather than blocking. - if (!getRPCHandle()) { + if (!getRPCServer()) { Res = cuStreamSynchronize(Stream); } else { do { Res = cuStreamQuery(Stream); - if (auto Err = getRPCHandle()->runServer()) + if (auto Err = getRPCServer()->runServer(*this)) return Err; } while (Res == CUDA_ERROR_NOT_READY); }