diff --git a/openmp/libomptarget/cmake/Modules/LibomptargetGetDependencies.cmake b/openmp/libomptarget/cmake/Modules/LibomptargetGetDependencies.cmake
--- a/openmp/libomptarget/cmake/Modules/LibomptargetGetDependencies.cmake
+++ b/openmp/libomptarget/cmake/Modules/LibomptargetGetDependencies.cmake
@@ -240,3 +240,21 @@
 endif()
 set(OPENMP_PTHREAD_LIB ${LLVM_PTHREAD_LIB})
+
+################################################################################
+# Looking for MPI...
+################################################################################
+find_package(MPI QUIET)
+
+set(LIBOMPTARGET_DEP_MPI_FOUND ${MPI_CXX_FOUND})
+set(LIBOMPTARGET_DEP_MPI_LIBRARIES ${MPI_CXX_LIBRARIES})
+set(LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS ${MPI_CXX_INCLUDE_DIRS})
+set(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS ${MPI_CXX_COMPILE_FLAGS})
+set(LIBOMPTARGET_DEP_MPI_LINK_FLAGS ${MPI_CXX_LINK_FLAGS})
+
+mark_as_advanced(
+  LIBOMPTARGET_DEP_MPI_FOUND
+  LIBOMPTARGET_DEP_MPI_LIBRARIES
+  LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS
+  LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS
+  LIBOMPTARGET_DEP_MPI_LINK_FLAGS)
diff --git a/openmp/libomptarget/include/omptargetplugin.h b/openmp/libomptarget/include/omptargetplugin.h
--- a/openmp/libomptarget/include/omptargetplugin.h
+++ b/openmp/libomptarget/include/omptargetplugin.h
@@ -192,6 +192,14 @@
 int32_t __tgt_rtl_init_device_info(int32_t ID, __tgt_device_info *DeviceInfoPtr,
                                    const char **ErrStr);
 
+// Check if the current code is being executed inside the device itself. If
+// that is the case, the device main function must be executed.
+int32_t __tgt_rtl_is_inside_device();
+
+// Run the device main function if supported. The program will exit immediately
+// after a successful execution.
+void __tgt_rtl_run_device_main(__tgt_bin_desc *desc);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/openmp/libomptarget/include/rtl.h b/openmp/libomptarget/include/rtl.h
--- a/openmp/libomptarget/include/rtl.h
+++ b/openmp/libomptarget/include/rtl.h
@@ -75,6 +75,8 @@
   typedef int32_t(init_async_info_ty)(int32_t, __tgt_async_info **);
   typedef int64_t(init_device_into_ty)(int64_t, __tgt_device_info *,
                                        const char **);
+  typedef int32_t(is_inside_device_ty)();
+  typedef void(run_device_main_ty)(__tgt_bin_desc *);
 
   int32_t Idx = -1; // RTL index, index is the number of devices
                     // of other RTLs that were registered before,
@@ -125,6 +127,8 @@
   init_async_info_ty *init_async_info = nullptr;
   init_device_into_ty *init_device_info = nullptr;
   release_async_info_ty *release_async_info = nullptr;
+  run_device_main_ty *run_device_main = nullptr;
+  is_inside_device_ty *is_inside_device = nullptr;
 
   // Are there images associated with this RTL.
   bool IsUsed = false;
@@ -142,6 +146,9 @@
   // List of the detected runtime libraries.
   std::list<RTLInfoTy> AllRTLs;
 
+  // List of runtime devices with a main function.
+  llvm::SmallVector<RTLInfoTy *> ExecutableRTLs;
+
   // Array of pointers to the detected runtime libraries that have compatible
   // binaries.
   llvm::SmallVector<RTLInfoTy *> UsedRTLs;
diff --git a/openmp/libomptarget/plugins/CMakeLists.txt b/openmp/libomptarget/plugins/CMakeLists.txt
--- a/openmp/libomptarget/plugins/CMakeLists.txt
+++ b/openmp/libomptarget/plugins/CMakeLists.txt
@@ -85,8 +85,8 @@
 add_subdirectory(ve)
 add_subdirectory(x86_64)
 add_subdirectory(remote)
+add_subdirectory(mpi)
 
 # Make sure the parent scope can see the plugins that will be created.
 set(LIBOMPTARGET_SYSTEM_TARGETS "${LIBOMPTARGET_SYSTEM_TARGETS}" PARENT_SCOPE)
 set(LIBOMPTARGET_TESTED_PLUGINS "${LIBOMPTARGET_TESTED_PLUGINS}" PARENT_SCOPE)
-
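Both new entry points are optional for a plugin. A minimal sketch of how the host runtime could drive them at startup (illustrative only: the helper name and the dispatch placement are assumptions; the actual wiring lives in libomptarget's RTL registration code):

  // Hypothetical helper showing the intended protocol of the new hooks.
  static void runDeviceMainIfNeeded(RTLInfoTy &RTL, __tgt_bin_desc *Desc) {
    // A plugin may leave both hooks unresolved.
    if (!RTL.is_inside_device || !RTL.run_device_main)
      return;
    // On an MPI worker process, run_device_main() enters the device main
    // loop and exits the process on success, so this call does not return.
    if (RTL.is_inside_device())
      RTL.run_device_main(Desc);
  }
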
diff --git a/openmp/libomptarget/plugins/mpi/CMakeLists.txt b/openmp/libomptarget/plugins/mpi/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/CMakeLists.txt
@@ -0,0 +1,72 @@
+##===----------------------------------------------------------------------===##
+#
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+##===----------------------------------------------------------------------===##
+#
+# Build a plugin for an MPI machine if available.
+#
+##===----------------------------------------------------------------------===##
+if (NOT(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(ppc64le)$" AND CMAKE_SYSTEM_NAME MATCHES "Linux"))
+  libomptarget_say("Not building MPI offloading plugin: MPI is only supported on Linux x86_64 or ppc64le hosts.")
+  return()
+elseif (NOT LIBOMPTARGET_DEP_LIBFFI_FOUND)
+  libomptarget_say("Not building MPI offloading plugin: libffi dependency not found.")
+  return()
+elseif(NOT LIBOMPTARGET_DEP_MPI_FOUND)
+  libomptarget_say("Not building MPI offloading plugin: MPI not found in system.")
+  return()
+endif()
+
+libomptarget_say("Building MPI offloading plugin.")
+
+# Source defines.
+# ==============================================================================
+# Define the suffix for the runtime messaging dumps.
+add_definitions(-DTARGET_NAME=MPI)
+# Define macro with the ELF ID for this target.
+add_definitions("-DTARGET_ELF_ID=62")
+
+add_llvm_library(omptarget.rtl.mpi SHARED
+  src/EventSystem.cpp
+  src/MPIManager.cpp
+  src/rtl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${LIBOMPTARGET_INCLUDE_DIR}
+  ${LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS}
+  ${LIBOMPTARGET_DEP_LIBFFI_INCLUDE_DIRS}
+
+  LINK_LIBS
+  PRIVATE
+  ${LIBOMPTARGET_DEP_MPI_LIBRARIES}
+  ${LIBOMPTARGET_DEP_LIBFFI_LIBRARIES}
+  elf_common
+  MemoryManager
+  dl
+  ${OPENMP_PTHREAD_LIB}
+  "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/../exports"
+  ${LIBOMPTARGET_DEP_MPI_LINK_FLAGS}
+
+  NO_INSTALL_RPATH
+)
+
+# Add include directories.
+target_include_directories(omptarget.rtl.mpi
+  PRIVATE include ${LIBOMPTARGET_INCLUDE_DIR}
+)
+
+# Install plugin under the lib destination folder.
+install(TARGETS omptarget.rtl.mpi LIBRARY DESTINATION "${OPENMP_INSTALL_LIBDIR}")
+set_target_properties(omptarget.rtl.mpi PROPERTIES
+  INSTALL_RPATH "$ORIGIN" BUILD_RPATH "$ORIGIN:${CMAKE_CURRENT_BINARY_DIR}/..")
+
+if(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS)
+  set_target_properties(omptarget.rtl.mpi PROPERTIES
+    COMPILE_FLAGS "${LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS}")
+endif()
+
+# Report to the parent scope that we are building a plugin for MPI.
+set(LIBOMPTARGET_SYSTEM_TARGETS "${LIBOMPTARGET_SYSTEM_TARGETS} x86_64-pc-linux-gnu-mpi" PARENT_SCOPE)
diff --git a/openmp/libomptarget/plugins/mpi/include/Common.h b/openmp/libomptarget/plugins/mpi/include/Common.h
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/include/Common.h
@@ -0,0 +1,29 @@
+//===----------------- Common.h - Common definitions ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common functionality for the MPI plugin.
+//
+//===----------------------------------------------------------------------===//
+
+#include <cassert>
+
+// Debug utilities definitions
+// ===========================================================================
+#ifndef TARGET_NAME
+#define TARGET_NAME MPI
+#endif
+#define DEBUG_PREFIX "Target " GETNAME(TARGET_NAME) " RTL"
+#include "Debug.h"
+
+#define CHECK(expr, msg, ...)                                                  \
+  if (!(expr)) {                                                               \
+    REPORT(msg, ##__VA_ARGS__);                                                \
+    return false;                                                              \
+  }
+
+#define assertm(expr, msg) assert(((void)msg, expr));
diff --git a/openmp/libomptarget/plugins/mpi/include/Coroutines.h b/openmp/libomptarget/plugins/mpi/include/Coroutines.h
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/include/Coroutines.h
@@ -0,0 +1,72 @@
+//===--------- Coroutines.h - C++17 coroutine emulation ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions to emulate coroutines in C++17.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _OMPTARGET_OMPCLUSTER_COROUTINES_H_
+#define _OMPTARGET_OMPCLUSTER_COROUTINES_H_
+
+// Support macros
+#define CONCAT_IMPL(x, y) x##y
+#define MACRO_CONCAT(x, y) CONCAT_IMPL(x, y)
+#define GET_MACRO(_1, name, ...) name
+
+// Coroutine control macros
+// NOTE: Please leave each macro definition on a single line or split it
+// using `\`.
+// =============================================================================
+// Begin a coroutine function.
+#define CO_BEGIN()                                                             \
+  do {                                                                         \
+    if (ResumeLocation != nullptr) {                                           \
+      goto *ResumeLocation;                                                    \
+    }                                                                          \
+  } while (false);
+// End a coroutine function, optionally returning a final value.
+#define CO_RETURN_VOID()                                                       \
+  do {                                                                         \
+    ResumeLocation = &&MACRO_CONCAT(COROUTINE_YIELD_, __LINE__);               \
+    MACRO_CONCAT(COROUTINE_YIELD_, __LINE__) :;                                \
+    return;                                                                    \
+  } while (false);
+#define CO_RETURN_VALUE(value)                                                 \
+  do {                                                                         \
+    ResumeLocation = &&MACRO_CONCAT(COROUTINE_YIELD_, __LINE__);               \
+    MACRO_CONCAT(COROUTINE_YIELD_, __LINE__) :;                                \
+    return value;                                                              \
+  } while (false);
+#define CO_RETURN(...)                                                         \
+  GET_MACRO(__VA_ARGS__, CO_RETURN_VALUE, CO_RETURN_VOID)(__VA_ARGS__)
+// Halt the coroutine execution. The next call will resume it.
+#define CO_YIELD_VOID()                                                        \
+  do {                                                                         \
+    ResumeLocation = &&MACRO_CONCAT(COROUTINE_YIELD_, __LINE__);               \
+    return;                                                                    \
+    MACRO_CONCAT(COROUTINE_YIELD_, __LINE__) :;                                \
+  } while (false);
+#define CO_YIELD_VALUE(value)                                                  \
+  do {                                                                         \
+    ResumeLocation = &&MACRO_CONCAT(COROUTINE_YIELD_, __LINE__);               \
+    return value;                                                              \
+    MACRO_CONCAT(COROUTINE_YIELD_, __LINE__) :;                                \
+  } while (false);
+#define CO_YIELD(...)                                                          \
+  GET_MACRO(__VA_ARGS__, CO_YIELD_VALUE, CO_YIELD_VOID)(__VA_ARGS__)
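+
+// Example (illustrative only; not used by the plugin): a function-object
+// coroutine built from the macros above. The holder only needs a
+// `ResumeLocation` member for the resume labels. Successive calls return 1,
+// then 2, and then 3 on every call after completion:
+//
+//   struct CountToThree {
+//     void *ResumeLocation = nullptr;
+//     int operator()() {
+//       CO_BEGIN();
+//       CO_YIELD(1);
+//       CO_YIELD(2);
+//       CO_RETURN(3);
+//     }
+//   };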
+
+// TODO: Refactor the event system to use this interface.
+// // Coroutine base structure.
+// template <typename ReturnType> struct Coroutine {
+//   using LabelPointer = void *;
+//   LabelPointer ResumeLocation = nullptr;
+
+//   virtual ReturnType operator()() = 0;
+// };
+
+#endif // _OMPTARGET_OMPCLUSTER_COROUTINES_H_
diff --git a/openmp/libomptarget/plugins/mpi/include/EventSystem.h b/openmp/libomptarget/plugins/mpi/include/EventSystem.h
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/include/EventSystem.h
@@ -0,0 +1,597 @@
+//===-------- EventSystem.h - Concurrent MPI communication ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the declarations of the MPI Event System used by the MPI
+// target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _OMPTARGET_OMPCLUSTER_EVENT_SYSTEM_H_
+#define _OMPTARGET_OMPCLUSTER_EVENT_SYSTEM_H_
+
+#include <atomic>
+#include <condition_variable>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <type_traits>
+
+#define MPICH_SKIP_MPICXX
+#include <mpi.h>
+
+#include "llvm/ADT/SmallVector.h"
+
+// External forward declarations.
+// =============================================================================
+class MPIManagerTy;
+struct __tgt_target_table;
+
+// Internal forward declarations and type aliases.
+// =============================================================================
+class BaseEventTy;
+enum class EventSystemStateTy;
+
+/// Automatically managed event pointer.
+///
+/// \note: Every event must always be accessed/stored in a shared_ptr
+/// structure. This allows for automatic memory management among the many
+/// threads of the libomptarget runtime.
+using EventPtr = std::shared_ptr<BaseEventTy>;
+
+// Event Types
+// =============================================================================
+/// The event location.
+///
+/// Enumerates whether an event is executing at its two possible locations:
+/// its origin or its destination.
+enum class EventLocationTy : bool { ORIG = 0, DEST };
+
+/// The event type (type of action it will perform).
+///
+/// Enumerates the available events. Each enum item should be accompanied by an
+/// event class derived from BaseEvent. All the events are executed at a remote
+/// MPI process targeted by the event.
+enum class EventTypeTy : int {
+  // Memory management.
+  ALLOC,  // Allocates a buffer at the remote process.
+  DELETE, // Deletes a buffer at the remote process.
+
+  // Data movement.
+  RETRIEVE, // Receives a buffer's data from a remote process.
+  SUBMIT,   // Sends a buffer's data to a remote process.
+  EXCHANGE, // Exchanges a buffer between two remote processes.
+
+  // Target region execution.
+  EXECUTE, // Executes a target region at the remote process.
+
+  // Local event used to wait on other events.
+  SYNC,
+
+  // Internal event system commands.
+  EXIT // Stops the event system execution at the remote process.
+};
+
+/// The event execution state.
+///
+/// Enumerates the event states during its lifecycle through the event system.
+/// New states should be added to the enum in the order they happen at the
+/// event system.
+enum class EventStateTy : int {
+  CREATED = 0, // Event was only created but not executed yet.
+  EXECUTING,   // Event is currently being executed in the background and it
+               // is registering new MPI requests.
+  WAITING,     // Event was executed and is now waiting for the MPI requests
+               // to complete.
+  FAILED,      // Event failed during execution.
+  FINISHED     // The event and its MPI requests are completed.
+};
+
+/// EventLocation to string conversion.
+///
+/// \returns the string representation of \p location.
+const char *toString(EventLocationTy location);
+
+/// EventType to string conversion.
+///
+/// \returns the string representation of \p type.
+const char *toString(EventTypeTy type);
+
+/// EventState to string conversion.
+///
+/// \returns the string representation of \p state.
+const char *toString(EventStateTy state);
+
+// Events
+// =============================================================================
+
+/// The base event of all event types.
+///
+/// This class contains both the common data stored and common procedures
+/// executed at all events. New events that derive from this class must comply
+/// with the following:
+/// - Declare the new EventType item;
+/// - Name the derived class as the concatenation of the new EventType name and
+///   the "Event" word. E.g.: ALLOC event -> class AllocEvent;
+/// - Implement both pure virtual functions:
+///   - #runOrigin;
+///   - #runDestination;
+/// - Implement two constructors (one for each EventLocation) with this
+///   prototype:
+///
+///     Event(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+///           ...);
+///
+class BaseEventTy {
+public:
+  /// The only (non-child) class that can access the Event protected members.
+  friend class EventSystemTy;
+
+  // Event definitions.
+  /// Location where the event is being executed: origin or destination rank.
+  const EventLocationTy EventLocation;
+  /// Event type that represents its actions.
+  const EventTypeTy EventType;
+
+  // MPI definitions.
+  /// MPI communicator to be used by the event.
+  const MPI_Comm TargetComm;
+  /// MPI tag that must be used on every MPI communication of the event.
+  const int MPITag;
+  /// Rank of the process that created the event.
+  const int OrigRank;
+  /// Rank of the process targeted by the event.
+  const int DestRank;
+
+protected:
+  using LabelPointer = void *;
+  /// Event coroutine state.
+  LabelPointer ResumeLocation = nullptr;
+
+private:
+  /// MPI non-blocking requests to be synchronized during event execution.
+  llvm::SmallVector<MPI_Request> PendingRequests;
+
+  /// The event execution state.
+  std::atomic<EventStateTy> EventState{EventStateTy::CREATED};
+
+  /// Call-guard for the #progress function.
+  ///
+  /// This atomic ensures only one thread executes the #progress code at a
+  /// time.
+  std::atomic<bool> ProgressGuard{false};
+
+  /// Parameters used to create the event at the destination process.
+  ///
+  /// \note this array is needed so non-blocking MPI messages can be used to
+  /// create the event.
+  uint32_t InitialRequestInfo[2] = {0, 0};
+
+public:
+  BaseEventTy &operator=(BaseEventTy &) = delete;
+  BaseEventTy(BaseEventTy &) = delete;
+
+  /// Advance the progress of the event.
+  ///
+  /// \note calling this function once does not guarantee that the event is
+  /// completely executed. One must call this function until #isDone returns
+  /// true.
+  void progress();
+
+  /// Check if the event is completed.
+  ///
+  /// \return true if the event is completed.
+  bool isDone() const;
+
+  /// Wait for the event to be completed.
+  ///
+  /// Waits for the completion of the event on both the local and remote
+  /// processes. This function will choose between the blocking and tasking
+  /// implementations depending on whether the current code is inside a task.
+  ///
+  /// \note this function is almost equivalent to calling #progress while
+  /// #isDone returns false.
+  void wait();
+
+  /// Get the current event execution state.
+  ///
+  /// \returns The current EventState.
+  EventStateTy getEventState() const;
+
+protected:
+  /// BaseEvent constructor.
+  BaseEventTy(EventLocationTy EventLocation, EventTypeTy EventType, int MPITag,
+              MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+  /// BaseEvent default destructor.
+  virtual ~BaseEventTy() = default;
+
+  /// Push a new MPI request to the request array.
+  ///
+  /// \returns a reference to the next available MPI request at
+  /// #PendingRequests.
+  MPI_Request *getNextRequest();
+
+  /// Test if all pending requests have finished.
+  ///
+  /// \return true if all pending MPI requests are completed, false otherwise.
+  bool checkPendingRequests();
+
+private:
+  /// Sends a new event notification to the destination.
+  void notifyNewEvent();
+
+  /// Calls the #runOrigin or #runDestination coroutine.
+  ///
+  /// \returns true if the event run coroutine has finished, false otherwise.
+  bool runCoroutine();
+
+  // Event coroutines
+  /// Executes the origin side of the event locally.
+  virtual bool runOrigin() = 0;
+  /// Executes the destination side of the event locally.
+  virtual bool runDestination() = 0;
+};
+
+/// Allocates a buffer at a remote process.
+class AllocEventTy final : public BaseEventTy {
+private:
+  /// Size of the buffer to be allocated.
+  int64_t Size = 0;
+  /// Pointer to a variable to be filled with the address of the allocated
+  /// buffer.
+  void **AllocatedAddressPtr = nullptr;
+  // Allocated address to be sent back.
+  void *DestAddress = 0;
+
+public:
+  /// Origin constructor.
+  AllocEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+               int64_t Size, void **AllocatedAddress);
+
+  /// Destination constructor.
+  AllocEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+private:
+  /// Sends the size and receives the allocated address.
+  bool runOrigin() override;
+  /// Receives the size, allocating the data and sending its address.
+  bool runDestination() override;
+};
+
+/// Frees a buffer at a remote process.
+class DeleteEventTy final : public BaseEventTy {
+private:
+  /// Address of the buffer to be freed.
+  void *TargetAddress = 0;
+
+public:
+  /// Origin constructor.
+  DeleteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+                void *TargetAddress);
+
+  /// Destination constructor.
+  DeleteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+private:
+  /// Sends the address and waits for a notification of the data deletion.
+  bool runOrigin() override;
+  /// Receives the address, frees it and sends a completion notification.
+  bool runDestination() override;
+};
+
+/// Retrieves a buffer from a remote process.
+class RetrieveEventTy final : public BaseEventTy {
+private:
+  /// Address of the origin's buffer to be filled with the destination's data.
+  void *OrigPtr = nullptr;
+  /// Address of the destination's buffer to be retrieved.
+  const void *DestPtr = nullptr;
+  /// Size of both the origin's and destination's buffers.
+  int64_t Size = 0;
+
+public:
+  /// Origin constructor.
+  RetrieveEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+                  void *OrigPtr, const void *DestPtr, int64_t Size);
+
+  /// Destination constructor.
+  RetrieveEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+private:
+  /// Sends the buffer info and retrieves its data.
+  bool runOrigin() override;
+  /// Receives the buffer info and sends its data.
+  bool runDestination() override;
+};
+
+/// Sends a buffer to a remote process.
+class SubmitEventTy final : public BaseEventTy {
+private:
+  /// Address of the origin's buffer to be submitted.
+  const void *OrigPtr = nullptr;
+  /// Address of the destination's buffer to be filled with the origin's data.
+  void *DestPtr = nullptr;
+  /// Size of both the origin's and destination's buffers.
+  int64_t Size = 0;
+
+public:
+  /// Origin constructor.
+  // TODO: Change to dest first then origin (target then host)
+  SubmitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+                const void *OrigPtr, void *DestPtr, int64_t Size);
+
+  /// Destination constructor.
+  SubmitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+private:
+  /// Sends the buffer info and then the buffer itself.
+  bool runOrigin() override;
+  /// Receives the buffer info and then the buffer itself.
+  bool runDestination() override;
+
+  friend class PackedSubmitEvent;
+};
+
+/// Exchanges a buffer between two remote processes.
+class ExchangeEventTy final : public BaseEventTy {
+private:
+  /// MPI rank of the data destination process.
+  int DataDestRank = 0;
+  /// Address of the data at the data source process.
+  const void *SrcPtr = nullptr;
+  /// Address of the data at the data destination process.
+  void *DstPtr = nullptr;
+  /// Size of both the data source's and data destination's buffers.
+  int64_t Size = 0;
+  /// Pointer to the remote submit event created at the remote source location.
+  EventPtr RemoteSubmitEvent = nullptr;
+
+public:
+  /// Origin constructor.
+  ExchangeEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+                  int DataDestRank, const void *SrcPtr, void *DstPtr,
+                  int64_t Size);
+
+  /// Destination constructor.
+  ExchangeEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+private:
+  /// Sends the buffer info.
+  bool runOrigin() override;
+  /// Receives the buffer info and starts a SubmitEvent to the dest process.
+  bool runDestination() override;
+};
+
+/// Executes a target region at a remote process.
+class ExecuteEventTy final : public BaseEventTy {
+private:
+  /// Number of arguments of the target region.
+  int32_t NumArgs = 0;
+  /// Arguments of the target region.
+  llvm::SmallVector<void *> Args{};
+  /// Index of the target region.
+  uint32_t TargetEntryIdx = -1;
+  // Local target table with entry addresses.
+  __tgt_target_table *TargetTable = nullptr;
+
+public:
+  /// Origin constructor.
+  ExecuteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+                 int32_t NumArgs, void **Args, uint32_t TargetEntryIdx);
+
+  /// Destination constructor.
+  ExecuteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+                 __tgt_target_table *TargetTable);
+
+private:
+  /// Sends the target region info and waits for the completion notification.
+  bool runOrigin() override;
+  /// Receives the target region info, executes it and sends the notification.
+  bool runDestination() override;
+};
+
+/// Local event used to wait on other events.
+class SyncEventTy final : public BaseEventTy {
+private:
+  EventPtr TargetEvent;
+
+public:
+  /// Destination constructor.
+  SyncEventTy(EventPtr &TargetEvent);
+
+private:
+  /// Does nothing.
+  bool runOrigin() override;
+  /// Waits for TargetEvent to complete.
+  bool runDestination() override;
+};
+
+/// Notifies a remote process to stop its event system.
+class ExitEventTy final : public BaseEventTy {
+private:
+  /// Pointer to the event system state.
+  std::atomic<EventSystemStateTy> *EventSystemState = nullptr;
+
+public:
+  /// Origin constructor.
+  ExitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank);
+
+  /// Destination constructor.
+  ExitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank, int DestRank,
+              std::atomic<EventSystemStateTy> *EventSystemState);
+
+private:
+  /// Just waits for the completion notification.
+  bool runOrigin() override;
+  /// Stops its event system and sends the notification.
+  bool runDestination() override;
+};
+
+// Event Queue
+// =============================================================================
+/// Event queue for received events.
+class EventQueue {
+private:
+  /// Base internal queue.
+  std::queue<EventPtr> Queue;
+  /// Base queue sync mutex.
+  std::mutex QueueMtx;
+
+  /// Condition variable to block popping on an empty queue.
+  std::condition_variable CanPopCv;
+
+public:
+  /// Event Queue default constructor.
+  EventQueue();
+
+  /// Gets current queue size.
+  size_t size();
+
+  /// Push an event to the queue, resizing it when needed.
+  void push(EventPtr &Event);
+
+  /// Pops an event from the queue, returning nullptr if the queue is empty.
+  EventPtr pop();
+};
+
+// Event System
+// =============================================================================
+
+/// MPI tags used in control messages.
+///
+/// Special tag values used to send control messages between the event systems
+/// of different processes. When adding new tags, please summarize the tag
+/// usage with a side comment as done below.
+enum class ControlTagsTy : int {
+  EVENT_REQUEST = 0, // Used by event handlers to receive new event requests.
+  FIRST_EVENT        // Tag used by the first event. Must always be placed
+                     // last.
+};
+
+/// Event system execution state.
+///
+/// Describes the event system state through the program.
+enum class EventSystemStateTy {
+  CREATED,     // ES was created but it is not ready to send or receive new
+               // events.
+  INITIALIZED, // ES was initialized alongside internal MPI states. It is ready
+               // to send new events, but not receive them.
+  RUNNING,     // ES is running and ready to receive new events.
+  EXITED       // ES was stopped.
+};
+
+/// The distributed event system.
+class EventSystemTy {
+public:
+  /// The largest MPI tag allowed by its implementation.
+  static int32_t MPITagMaxValue;
+
+  /// Communicator used by the gate thread.
+  // TODO: Find a better way to share this with all the events. static is not
+  // that great.
+  static MPI_Comm GateThreadComm;
+
+private:
+  // MPI definitions.
+  /// Communicator pool distributed over the events.
+  llvm::SmallVector<MPI_Comm> EventCommPool{};
+  /// Number of processes used by the event system.
+  int WorldSize = -1;
+  /// The local rank of the current instance.
+  int LocalRank = -1;
+
+  /// Number of events created by the current instance.
+  std::atomic<int> EventCounter{0};
+
+  /// Event queues between the local gate thread and the event handlers.
+  EventQueue ExecEventQueue{};
+  EventQueue DataEventQueue{};
+
+  /// Event System execution state.
+  std::atomic<EventSystemStateTy> EventSystemState{};
+
+  bool IsInitialized = false;
+
+private:
+  /// Function executed by the event handler threads.
+  void runEventHandler(EventQueue &Queue);
+
+  /// Creates a new unique event tag for a new event.
+  int createNewEventTag();
+
+  /// Gets a comm for a new event from the comm pool.
+  MPI_Comm &getNewEventComm(int MPITag);
+
+public:
+  /// Creates a local MPI context containing an exclusive comm for the gate
+  /// thread, and a comm pool to be used internally by the events. It also
+  /// acquires the local MPI process description.
+  bool createLocalMPIContext();
+
+  /// Destroys the local MPI context and all of its comms.
+  bool destroyLocalMPIContext();
+
+  EventSystemTy();
+  ~EventSystemTy();
+
+  bool initialize();
+  bool deinitialize();
+
+  /// Creates a new event.
+  ///
+  /// Creates a new event of 'EventClass' type targeting the 'DestRank'. The
+  /// 'Args' parameters are additional arguments that may be passed to the
+  /// EventClass origin constructor.
+  ///
+  /// \note: since this is a template function, it must be defined in
+  /// this header.
+  template <class EventClassTy, class... ArgsTy>
+  EventPtr createEvent(int DestRank, ArgsTy &&...Args);
+
+  /// Gate thread procedure.
+  ///
+  /// Caller thread will spawn the event handlers, execute the gate logic and
+  /// wait until the event system receives an Exit event.
+  void runGateThread(__tgt_target_table *TargetTable);
+
+  /// Get the number of workers available.
+  ///
+  /// \return the number of available MPI workers.
+  int getNumWorkers() const;
+
+  /// Check if we are at the host MPI process.
+  ///
+  /// \return true if the current MPI process is the host (rank 0), false
+  /// otherwise.
+  int isHead() const;
+};
+
+template <class EventClassTy, class... ArgsTy>
+EventPtr EventSystemTy::createEvent(int DstDeviceID, ArgsTy &&...Args) {
+  static_assert(std::is_convertible_v<EventClassTy *, BaseEventTy *>,
+                "Cannot create an event from a class that is not derived from "
+                "the BaseEvent class.");
+  using MPITagTy = int;
+  using RankTy = int;
+  static_assert(std::is_constructible_v<EventClassTy, MPITagTy, MPI_Comm,
+                                        RankTy, RankTy, ArgsTy...>,
+                "Cannot create an event from the given argument types.");
+
+  // MPI rank 0 is our head node/host. Worker ranks start at 1.
+  const int DstDeviceRank = DstDeviceID + 1;
+
+  const int EventTag = createNewEventTag();
+  auto &EventComm = getNewEventComm(EventTag);
+
+  EventPtr Event = std::make_shared<EventClassTy>(
+      EventTag, EventComm, LocalRank, DstDeviceRank,
+      std::forward<ArgsTy>(Args)...);
+
+  return Event;
+}
+
+#endif // _OMPTARGET_OMPCLUSTER_EVENT_SYSTEM_H_
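For reference, a sketch of how a plugin-side call site is expected to use `createEvent` together with the polling interface (hypothetical values; `EventSystem` is the manager's `EventSystemTy` member and the real call sites live in MPIManager.cpp):

  // Allocate 1 KiB on device 0 (MPI rank 1) and block until both sides of
  // the event have completed.
  void *DevicePtr = nullptr;
  EventPtr Ev = EventSystem.createEvent<AllocEventTy>(/*DstDeviceID=*/0,
                                                      /*Size=*/1024,
                                                      &DevicePtr);
  Ev->wait(); // polls progress() until isDone()
  bool Succeeded = Ev->getEventState() == EventStateTy::FINISHED;
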
diff --git a/openmp/libomptarget/plugins/mpi/include/MPIManager.h b/openmp/libomptarget/plugins/mpi/include/MPIManager.h
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/include/MPIManager.h
@@ -0,0 +1,222 @@
+//===----------- MPIManager.h - MPI RTL Definitions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations for the MPI RTL plugin.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _OMPTARGET_OMPCLUSTER_MPI_MANAGER_H_
+#define _OMPTARGET_OMPCLUSTER_MPI_MANAGER_H_
+
+#include <list>
+#include <memory>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/DynamicLibrary.h"
+
+#include "Common.h"
+#include "EventSystem.h"
+
+#include "MemoryManager.h"
+#include "omptarget.h"
+
+using DynamicLibrary = llvm::sys::DynamicLibrary;
+template <typename T> using SmallVector = llvm::SmallVector<T>;
+
+// MPI Manager.
+// ============================================================================
+/// Class containing all the device information.
+class MPIManagerTy {
+  /// Array of dynamic libraries loaded for this target.
+  struct DynLibTy {
+    std::string FileName;
+    std::unique_ptr<DynamicLibrary> DynLib;
+  };
+
+  /// Keep entries table per device.
+  struct FuncOrGblEntryTy {
+    __tgt_target_table Table;
+    SmallVector<__tgt_offload_entry> Entries;
+  };
+
+  // Memory Allocator.
+  // ==========================================================================
+  /// A class responsible for interacting with the device native runtime
+  /// library to allocate and free memory.
+  class MPIDeviceAllocatorTy : public DeviceAllocatorTy {
+    const int DeviceId;
+    EventSystemTy &EventSystem;
+
+  public:
+    MPIDeviceAllocatorTy(int DeviceId, EventSystemTy &EventSystem)
+        : DeviceId(DeviceId), EventSystem(EventSystem) {}
+
+    void *allocate(size_t Size, void *HostPtr,
+                   TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) override;
+
+    int free(void *TargetPtr,
+             TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) override;
+  };
+
+  // Distributed event system responsible for hiding communications between
+  // nodes.
+  EventSystemTy EventSystem;
+
+  // Memory manager
+  // ==========================================================================
+  // Whether to use the memory manager.
+  bool UseMemoryManager = true;
+  // A vector of device allocators.
+  SmallVector<MPIDeviceAllocatorTy> ProcessAllocators{};
+  // A vector of memory managers. Since the memory manager is non-copyable and
+  // non-movable, we wrap them into std::unique_ptr.
+  SmallVector<std::unique_ptr<MemoryManagerTy>> MemoryManagers{};
+
+  // Dynamic libraries and function list
+  // ==========================================================================
+  std::list<DynLibTy> DynLibs{};
+  SmallVector<std::list<FuncOrGblEntryTy>> FuncGblEntries{};
+
+  // TODO: Use atomic and query for it when setting.
+  bool IsInitialized = false;
+
+  // Should be non-copyable and non-movable.
+public:
+  MPIManagerTy(MPIManagerTy &) = delete;
+  MPIManagerTy &operator=(MPIManagerTy &) = delete;
+  MPIManagerTy(MPIManagerTy &&) = delete;
+  MPIManagerTy &operator=(MPIManagerTy &&) = delete;
+
+  // De/initialization functions.
+  // ==========================================================================
+public:
+  MPIManagerTy() {}
+  ~MPIManagerTy();
+
+  bool initialize();
+  bool deinitialize();
+
+  // Dynamic library loading and handling.
+  // ==========================================================================
+public:
+  // Load a binary into a device context.
+  __tgt_target_table *loadBinary(const int DeviceId,
+                                 const __tgt_device_image *Image);
+
+  // Check whether the given binary is valid for the plugin.
+  int32_t isValidBinary(__tgt_device_image *Image) const;
+
+private:
+  // Record entry point associated with a device.
+  void createOffloadTable(int32_t DeviceId,
+                          SmallVector<__tgt_offload_entry> &&Entries);
+
+  // Return true if the entry is associated with the device.
+  bool findOffloadEntry(int32_t DeviceId, void *Addr);
+
+  // Return the pointer to the target entries table.
+  __tgt_target_table *getOffloadEntriesTable(int32_t DeviceId);
+
+  // Return the pointer to the target entries table.
+  __tgt_target_table *getOffloadEntriesTableOnWorker();
+
+  // Register the shared library to the current device.
+  void registerLibOnWorker(__tgt_bin_desc *Desc);
+
+  __tgt_target_table *loadBinaryOnWorker(const __tgt_device_image *Image);
+
+  // Plugin and device information.
+  // ==========================================================================
+public:
+  bool isValidDeviceId(const int DeviceId) const;
+
+  int getNumOfDevices() const;
+
+  // Valid check helpers.
+  // ==========================================================================
+private:
+  bool checkValidDeviceId(const int DeviceId) const;
+
+  bool checkValidAsyncInfo(const __tgt_async_info *AsyncInfo) const;
+
+  int32_t checkCreatedEvent(const EventPtr &Event) const;
+
+  bool checkRecordedEventPtr(const void *Event) const;
+
+  // Data management.
+  // ==========================================================================
+public:
+  void *dataAlloc(int32_t DeviceId, int64_t Size, void *HostPtr,
+                  TargetAllocTy Kind);
+
+  int32_t dataDelete(int32_t DeviceId, void *TargetPtr, TargetAllocTy Kind);
+
+  int32_t dataSubmit(int32_t DeviceId, void *TargetPtr, void *HostPtr,
+                     int64_t Size, __tgt_async_info *AsyncInfo);
+
+  int32_t dataRetrieve(int32_t DeviceId, void *HostPtr, void *TargetPtr,
+                       int64_t Size, __tgt_async_info *AsyncInfo);
+
+  int32_t dataExchange(int32_t SrcID, void *SrcPtr, int32_t DstId,
+                       void *DstPtr, int64_t Size,
+                       __tgt_async_info *AsyncInfo);
+
+  // Target execution.
+  // ==========================================================================
+public:
+  int32_t runTargetRegion(int32_t DeviceId, void *Entry, void **Args,
+                          ptrdiff_t *Offsets, int32_t NumArgs,
+                          __tgt_async_info *AsyncInfo);
+
+  // Event queueing and synchronization.
+  // ==========================================================================
+public:
+  int32_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo);
+
+private:
+  // Acquire the async context from the async info object. If no context is
+  // present, a new one is created.
+  using EventQueue = SmallVector<EventPtr>;
+  EventQueue *getEventQueue(__tgt_async_info *AsyncInfo);
+
+  // Push a new event to the respective device queue, updating the async info
+  // context.
+  void pushNewEvent(const EventPtr &Event, __tgt_async_info *AsyncInfo);
+
+  // Device side functions.
+  // ==========================================================================
+public:
+  // Return true if currently being executed inside the device.
+  bool isInsideDevice();
+
+  // Start the device main loop for worker ranks.
+  void runDeviceMain(__tgt_bin_desc *Desc);
+
+  // External events management.
+  // ==========================================================================
+public:
+  // Allocates a shared pointer to an event.
+  int32_t createEvent(int32_t DeviceId, void **Event);
+
+  // Destroys a shared pointer to an event.
+  int32_t destroyEvent(int32_t DeviceId, void *Event);
+
+  // Binds Event to the last internal event present in the event queue.
+  int32_t recordEvent(int32_t DeviceId, void *Event,
+                      __tgt_async_info *AsyncInfo);
+
+  // Adds the `Event` to the event queue so we can wait for it. `Event` might
+  // come from another device event queue (even on another task context),
+  // allowing two tasks to synchronize their inner events when needed (e.g.:
+  // wait for data to be submitted).
+  int32_t waitEvent(int32_t DeviceId, void *Event,
+                    __tgt_async_info *AsyncInfo);
+
+  // Waits for the Event, blocking the caller thread.
+  int32_t syncEvent(int32_t DeviceId, void *Event);
+};
+
+#endif // _OMPTARGET_OMPCLUSTER_MPI_MANAGER_H_
diff --git a/openmp/libomptarget/plugins/mpi/src/EventSystem.cpp b/openmp/libomptarget/plugins/mpi/src/EventSystem.cpp
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/src/EventSystem.cpp
@@ -0,0 +1,1137 @@
+//===------- EventSystem.cpp - Concurrent MPI communication -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the
+// MPI target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstdlib>
+#include <ffi.h>
+#include <limits>
+#include <string>
+#include <thread>
+
+#include "Common.h"
+#include "Coroutines.h"
+
+#include "omptarget.h"
+
+// Helper coroutine macros
+#define EVENT_BEGIN() CO_BEGIN()
+#define EVENT_PAUSE() CO_YIELD(false)
+#define EVENT_PAUSE_FOR_REQUESTS()                                             \
+  while (!checkPendingRequests()) {                                            \
+    EVENT_PAUSE();                                                             \
+  }
+#define EVENT_END() CO_RETURN(true)
+
+// Customizable parameters of the event system
+// =============================================================================
+// Here we declare some configuration variables with their respective default
+// values. Every single one of them can be tuned by an environment variable
+// with the following name pattern: OMPCLUSTER_VAR_NAME.
+namespace config {
+// Maximum buffer size to use during data transfer.
+static int64_t MPI_FRAGMENT_SIZE = 100e6;
+// Number of execute event handlers to spawn.
+static int NUM_EXEC_EVENT_HANDLERS = 1;
+// Number of data event handlers to spawn.
+static int NUM_DATA_EVENT_HANDLERS = 1;
+// Polling rate period (us) used by event handlers.
+static int EVENT_POLLING_RATE = 1;
+// Number of communicators to be spawned and distributed for the events.
+// Allows for parallel use of network resources.
+static int64_t NUM_EVENT_COMM = 10;
+} // namespace config
+
+// Helper functions
+// =============================================================================
+const char *toString(EventTypeTy Type) {
+  switch (Type) {
+  case EventTypeTy::ALLOC:
+    return "Alloc";
+  case EventTypeTy::DELETE:
+    return "Delete";
+  case EventTypeTy::RETRIEVE:
+    return "Retrieve";
+  case EventTypeTy::SUBMIT:
+    return "Submit";
+  case EventTypeTy::EXCHANGE:
+    return "Exchange";
+  case EventTypeTy::EXECUTE:
+    return "Execute";
+  case EventTypeTy::SYNC:
+    return "Sync";
+  case EventTypeTy::EXIT:
+    return "Exit";
+  }
+
+  assertm(false, "Every enum value must be checked on the switch above.");
+  return nullptr;
+}
+
+const char *toString(EventLocationTy Location) {
+  switch (Location) {
+  case EventLocationTy::DEST:
+    return "Destination";
+  case EventLocationTy::ORIG:
+    return "Origin";
+  }
+
+  assertm(false, "Every enum value must be checked on the switch above.");
+  return nullptr;
+}
+
+const char *toString(EventStateTy State) {
+  switch (State) {
+  case EventStateTy::CREATED:
+    return "Created";
+  case EventStateTy::EXECUTING:
+    return "Executing";
+  case EventStateTy::WAITING:
+    return "Waiting";
+  case EventStateTy::FAILED:
+    return "Failed";
+  case EventStateTy::FINISHED:
+    return "Finished";
+  }
+
+  assertm(false, "Every enum value must be checked on the switch above.");
+  return nullptr;
+}
+
+// Base Event implementation
+// =============================================================================
+BaseEventTy::BaseEventTy(EventLocationTy EventLocation, EventTypeTy EventType,
+                         int MPITag, MPI_Comm TargetComm, int OrigRank,
+                         int DestRank)
+    : EventLocation(EventLocation), EventType(EventType),
+      TargetComm(TargetComm), MPITag(MPITag), OrigRank(OrigRank),
+      DestRank(DestRank) {
+  assertm(MPITag >= static_cast<int>(ControlTagsTy::FIRST_EVENT),
+ "Event MPI tag must not have a Control Tag value"); + assertm(MPITag <= EventSystemTy::MPITagMaxValue, + "Event MPI tag must be smaller than the maximum value allowed"); +} + +void BaseEventTy::notifyNewEvent() { + if (EventLocation == EventLocationTy::ORIG) { + // Sends event request. + InitialRequestInfo[0] = static_cast(EventType); + InitialRequestInfo[1] = static_cast(MPITag); + + if (OrigRank != DestRank) + MPI_Isend(InitialRequestInfo, 2, MPI_UINT32_T, DestRank, + static_cast(ControlTagsTy::EVENT_REQUEST), + EventSystemTy::GateThreadComm, getNextRequest()); + } +} + +bool BaseEventTy::runCoroutine() { + if (EventLocation == EventLocationTy::ORIG) { + return runOrigin(); + } else { + return runDestination(); + } +} + +bool BaseEventTy::checkPendingRequests() { + int RequestsCompleted = false; + + MPI_Testall(PendingRequests.size(), PendingRequests.data(), + &RequestsCompleted, MPI_STATUSES_IGNORE); + + return RequestsCompleted; +} + +bool BaseEventTy::isDone() const { + return (EventState == EventStateTy::FINISHED) || + (EventState == EventStateTy::FAILED); +} + +void BaseEventTy::progress() { + // Immediately return if the event is already finished + if (isDone()) { + return; + } + + // The following code block uses a guard to ensure only one thread is + // executing the progress function at a time, returning immediately for all + // the other threads (e.g. multiple threads waiting on the same event). + // + // If one thread is already advancing the event execution at a node, there is + // no need for other threads to execute the progress function. Returning + // immediately frees other threads to execute other events/procedures and + // allows the events run* coroutines to be implemented in a not thead-safe + // manner. + bool ExpectedProgressGuard = false; + if (!ProgressGuard.compare_exchange_weak(ExpectedProgressGuard, true)) { + return; + } + + // Advance the event local execution depending on its state. + switch (EventState) { + case EventStateTy::CREATED: + notifyNewEvent(); + EventState = EventStateTy::EXECUTING; + [[fallthrough]]; + + case EventStateTy::EXECUTING: + if (!runCoroutine()) + break; + EventState = EventStateTy::WAITING; + [[fallthrough]]; + + case EventStateTy::WAITING: + if (!checkPendingRequests()) + break; + if (EventState != EventStateTy::FAILED) + EventState = EventStateTy::FINISHED; + break; + + case EventStateTy::FAILED: + REPORT("MPI event %s failed.\n", toString(EventType)); + [[fallthrough]]; + + case EventStateTy::FINISHED: + break; + } + + // Allow other threads to call progress again. + ProgressGuard = false; +} + +void BaseEventTy::wait() { + // Advance the event progress until it is completed. 
+  while (!isDone()) {
+    progress();
+
+    std::this_thread::sleep_for(
+        std::chrono::microseconds(config::EVENT_POLLING_RATE));
+  }
+}
+
+EventStateTy BaseEventTy::getEventState() const { return EventState; }
+
+MPI_Request *BaseEventTy::getNextRequest() {
+  PendingRequests.emplace_back(MPI_REQUEST_NULL);
+  return &PendingRequests.back();
+}
+
+// Alloc Event implementation
+// =============================================================================
+AllocEventTy::AllocEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                           int DestRank, int64_t Size, void **AllocatedAddress)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::ALLOC, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      Size(Size), AllocatedAddressPtr(AllocatedAddress) {
+  assertm(Size >= 0, "AllocEvent must receive a Size >= 0");
+  assertm(AllocatedAddress != nullptr,
+          "AllocEvent must receive a valid pointer as AllocatedAddress");
+}
+
+bool AllocEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  EVENT_BEGIN();
+
+  MPI_Isend(&Size, 1, MPI_INT64_T, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  MPI_Irecv(AllocatedAddressPtr, sizeof(uintptr_t), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  EVENT_END();
+}
+
+AllocEventTy::AllocEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                           int DestRank)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::ALLOC, MPITag,
+                  TargetComm, OrigRank, DestRank) {}
+
+bool AllocEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  EVENT_BEGIN();
+
+  MPI_Irecv(&Size, 1, MPI_INT64_T, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  DestAddress = malloc(Size);
+
+  MPI_Isend(&DestAddress, sizeof(uintptr_t), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  EVENT_END();
+}
+
+// Delete Event implementation
+// =============================================================================
+DeleteEventTy::DeleteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                             int DestRank, void *TargetAddress)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::DELETE, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      TargetAddress(TargetAddress) {}
+
+bool DeleteEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  EVENT_BEGIN();
+
+  MPI_Isend(&TargetAddress, sizeof(void *), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  // Event completion notification
+  MPI_Irecv(nullptr, 0, MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+DeleteEventTy::DeleteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                             int DestRank)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::DELETE, MPITag,
+                  TargetComm, OrigRank, DestRank) {}
+
+bool DeleteEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  EVENT_BEGIN();
+
+  MPI_Irecv(&TargetAddress, sizeof(void *), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  free(TargetAddress);
+
+  // Event completion notification
+  MPI_Isend(nullptr, 0, MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+// Retrieve Event implementation
+// =============================================================================
+RetrieveEventTy::RetrieveEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                                 int DestRank, void *OrigPtr,
+                                 const void *DestPtr, int64_t Size)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::RETRIEVE, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      OrigPtr(OrigPtr), DestPtr(DestPtr), Size(Size) {
+  assertm(Size >= 0, "RetrieveEvent must receive a Size >= 0");
+  assertm(OrigPtr != nullptr,
+          "RetrieveEvent must receive a valid pointer as OrigPtr");
+  assertm(DestPtr != nullptr,
+          "RetrieveEvent must receive a valid pointer as DestPtr");
+}
+
+bool RetrieveEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  char *BufferByteArray = nullptr;
+  int64_t RemainingBytes = 0;
+
+  EVENT_BEGIN();
+
+  MPI_Isend(&DestPtr, sizeof(uintptr_t), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Isend(&Size, sizeof(int64_t), MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  // TODO: Extract this to a common function for both dest/orig for
+  // submit/retrieve.
+  // Operates over many fragments of the original buffer of at
+  // most MPI_FRAGMENT_SIZE bytes.
+  BufferByteArray = reinterpret_cast<char *>(OrigPtr);
+  RemainingBytes = Size;
+  while (RemainingBytes > 0) {
+    MPI_Irecv(
+        &BufferByteArray[Size - RemainingBytes],
+        static_cast<int>(std::min(RemainingBytes, config::MPI_FRAGMENT_SIZE)),
+        MPI_BYTE, DestRank, MPITag, TargetComm, getNextRequest());
+    RemainingBytes -= config::MPI_FRAGMENT_SIZE;
+  }
+
+  EVENT_END();
+}
+
+RetrieveEventTy::RetrieveEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                                 int DestRank)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::RETRIEVE, MPITag,
+                  TargetComm, OrigRank, DestRank) {}
+
+bool RetrieveEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  const char *BufferByteArray = nullptr;
+  int64_t RemainingBytes = 0;
+
+  EVENT_BEGIN();
+
+  MPI_Irecv(&DestPtr, sizeof(uintptr_t), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Irecv(&Size, sizeof(int64_t), MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  // Operates over many fragments of the original buffer of at most
+  // MPI_FRAGMENT_SIZE bytes.
+  BufferByteArray = reinterpret_cast<const char *>(DestPtr);
+  RemainingBytes = Size;
+  while (RemainingBytes > 0) {
+    MPI_Isend(
+        &BufferByteArray[Size - RemainingBytes],
+        static_cast<int>(std::min(RemainingBytes, config::MPI_FRAGMENT_SIZE)),
+        MPI_BYTE, OrigRank, MPITag, TargetComm, getNextRequest());
+    RemainingBytes -= config::MPI_FRAGMENT_SIZE;
+  }
+
+  EVENT_END();
+}
+
+// Submit Event implementation
+// =============================================================================
+SubmitEventTy::SubmitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                             int DestRank, const void *OrigPtr, void *DestPtr,
+                             int64_t Size)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::SUBMIT, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      OrigPtr(OrigPtr), DestPtr(DestPtr), Size(Size) {
+  assertm(Size >= 0, "SubmitEvent must receive a Size >= 0");
+  assertm(OrigPtr != nullptr,
+          "SubmitEvent must receive a valid pointer as OrigPtr");
+  assertm(DestPtr != nullptr,
+          "SubmitEvent must receive a valid pointer as DestPtr");
+}
+
+bool SubmitEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  const char *BufferByteArray;
+  int64_t RemainingBytes;
+
+  EVENT_BEGIN();
+
+  MPI_Isend(&DestPtr, sizeof(uintptr_t), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Isend(&Size, sizeof(int64_t), MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  // Operates over many fragments of the original buffer of at most
+  // MPI_FRAGMENT_SIZE bytes.
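+  // Fragmenting keeps every per-message count within the `int` range that the
+  // MPI C API requires and lets several transfers be in flight at once.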
+  BufferByteArray = reinterpret_cast<const char *>(OrigPtr);
+  RemainingBytes = Size;
+  while (RemainingBytes > 0) {
+    MPI_Isend(
+        &BufferByteArray[Size - RemainingBytes],
+        static_cast<int>(std::min(RemainingBytes, config::MPI_FRAGMENT_SIZE)),
+        MPI_BYTE, DestRank, MPITag, TargetComm, getNextRequest());
+    RemainingBytes -= config::MPI_FRAGMENT_SIZE;
+  }
+
+  // Event completion notification
+  MPI_Irecv(nullptr, 0, MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+SubmitEventTy::SubmitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                             int DestRank)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::SUBMIT, MPITag,
+                  TargetComm, OrigRank, DestRank) {}
+
+bool SubmitEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  char *BufferByteArray = nullptr;
+  int64_t RemainingBytes = 0;
+
+  EVENT_BEGIN();
+
+  MPI_Irecv(&DestPtr, sizeof(uintptr_t), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Irecv(&Size, sizeof(int64_t), MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  // Operates over many fragments of the original buffer of at most
+  // MPI_FRAGMENT_SIZE bytes.
+  BufferByteArray = reinterpret_cast<char *>(DestPtr);
+  RemainingBytes = Size;
+  while (RemainingBytes > 0) {
+    MPI_Irecv(
+        &BufferByteArray[Size - RemainingBytes],
+        static_cast<int>(std::min(RemainingBytes, config::MPI_FRAGMENT_SIZE)),
+        MPI_BYTE, OrigRank, MPITag, TargetComm, getNextRequest());
+    RemainingBytes -= config::MPI_FRAGMENT_SIZE;
+  }
+
+  // Event completion notification
+  MPI_Isend(nullptr, 0, MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+// Exchange Event implementation
+// =============================================================================
+ExchangeEventTy::ExchangeEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                                 int DestRank, int DataDestRank,
+                                 const void *SrcPtr, void *DstPtr,
+                                 int64_t Size)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::EXCHANGE, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      DataDestRank(DataDestRank), SrcPtr(SrcPtr), DstPtr(DstPtr), Size(Size) {
+  assertm(Size >= 0, "ExchangeEvent must receive a Size >= 0");
+  assertm(SrcPtr != nullptr,
+          "ExchangeEvent must receive a valid pointer as SrcPtr");
+  assertm(DstPtr != nullptr,
+          "ExchangeEvent must receive a valid pointer as DstPtr");
+}
+
+bool ExchangeEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  EVENT_BEGIN();
+
+  MPI_Isend(&DataDestRank, sizeof(int), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Isend(&SrcPtr, sizeof(uintptr_t), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Isend(&DstPtr, sizeof(uintptr_t), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Isend(&Size, sizeof(int64_t), MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  // Event completion notification
+  MPI_Irecv(nullptr, 0, MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+ExchangeEventTy::ExchangeEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                                 int DestRank)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::EXCHANGE, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      DataDestRank(-1), SrcPtr(nullptr), DstPtr(nullptr), Size(0) {}
+
+bool ExchangeEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  EVENT_BEGIN();
+
+  MPI_Irecv(&DataDestRank, sizeof(int), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
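+  // The source pointer, destination pointer and buffer size follow as
+  // separate messages on the same tag.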
+  MPI_Irecv(&SrcPtr, sizeof(uintptr_t), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Irecv(&DstPtr, sizeof(uintptr_t), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  MPI_Irecv(&Size, sizeof(int64_t), MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  RemoteSubmitEvent = std::make_shared<SubmitEventTy>(
+      MPITag, TargetComm, DestRank, DataDestRank, SrcPtr, DstPtr, Size);
+
+  do {
+    RemoteSubmitEvent->progress();
+
+    if (!RemoteSubmitEvent->isDone()) {
+      EVENT_PAUSE();
+    }
+  } while (!RemoteSubmitEvent->isDone());
+
+  // Event completion notification
+  MPI_Isend(nullptr, 0, MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+// Execute Event implementation
+// =============================================================================
+ExecuteEventTy::ExecuteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                               int DestRank, int32_t NumArgs, void **ArgsArray,
+                               uint32_t TargetEntryIdx)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::EXECUTE, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      NumArgs(NumArgs), Args(NumArgs, nullptr),
+      TargetEntryIdx(TargetEntryIdx) {
+  assertm(NumArgs >= 0, "ExecuteEvent must receive NumArgs >= 0");
+  assertm(NumArgs == 0 || ArgsArray != nullptr,
+          "ExecuteEvent must receive a valid Args when NumArgs > 0");
+
+  std::copy_n(ArgsArray, NumArgs, Args.begin());
+}
+
+bool ExecuteEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  EVENT_BEGIN();
+
+  MPI_Isend(&NumArgs, sizeof(int32_t), MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  MPI_Isend(Args.data(), NumArgs * sizeof(uintptr_t), MPI_BYTE, DestRank,
+            MPITag, TargetComm, getNextRequest());
+
+  MPI_Isend(&TargetEntryIdx, sizeof(uint32_t), MPI_BYTE, DestRank, MPITag,
+            TargetComm, getNextRequest());
+
+  // Event completion notification
+  MPI_Irecv(nullptr, 0, MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+ExecuteEventTy::ExecuteEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                               int DestRank, __tgt_target_table *TargetTable)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::EXECUTE, MPITag,
+                  TargetComm, OrigRank, DestRank),
+      TargetTable(TargetTable) {
+  assertm(TargetTable != nullptr,
+          "ExecuteEvent must receive a valid pointer as TargetTable");
+}
+
+bool ExecuteEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  __tgt_offload_entry *Begin = nullptr;
+  __tgt_offload_entry *End = nullptr;
+  __tgt_offload_entry *Curr = nullptr;
+  ffi_cif Cif{};
+  llvm::SmallVector<ffi_type *> ArgsTypes{};
+  ffi_status FFIStatus [[maybe_unused]] = FFI_OK;
+  void (*TargetEntry)(void) = nullptr;
+
+  EVENT_BEGIN();
+
+  MPI_Irecv(&NumArgs, sizeof(int32_t), MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  Args.resize(NumArgs, nullptr);
+  ArgsTypes.resize(NumArgs, &ffi_type_pointer);
+  MPI_Irecv(Args.data(), NumArgs * sizeof(uintptr_t), MPI_BYTE, OrigRank,
+            MPITag, TargetComm, getNextRequest());
+
+  MPI_Irecv(&TargetEntryIdx, sizeof(uint32_t), MPI_BYTE, OrigRank, MPITag,
+            TargetComm, getNextRequest());
+
+  EVENT_PAUSE_FOR_REQUESTS();
+
+  // Iterates over all the host table entries to see if we can locate the
+  // host_ptr.
+  Begin = TargetTable->EntriesBegin;
+  End = TargetTable->EntriesEnd;
+  Curr = Begin;
+
+  // Iterates over all the table entries to see if we can locate the entry.
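+  // TargetEntryIdx was chosen on the origin side over the same table layout,
+  // so a linear scan for the matching index is sufficient here.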
+  for (uint32_t I = 0; Curr < End; ++Curr, ++I) {
+    if (I == TargetEntryIdx) {
+      // We got a match: use the table entry address as the target entry point.
+      *((void **)&TargetEntry) = Curr->addr;
+      break;
+    }
+  }
+
+  // Fail if the entry was not found.
+  assertm(Curr != End, "Could not find the right entry");
+
+  FFIStatus = ffi_prep_cif(&Cif, FFI_DEFAULT_ABI, NumArgs, &ffi_type_void,
+                           &ArgsTypes[0]);
+
+  assertm(FFIStatus == FFI_OK, "Unable to prepare target launch!");
+
+  ffi_call(&Cif, TargetEntry, nullptr, &Args[0]);
+
+  // Event completion notification
+  MPI_Isend(nullptr, 0, MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+// Sync Event implementation
+// =============================================================================
+SyncEventTy::SyncEventTy(EventPtr &TargetEvent)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::SYNC,
+                  EventSystemTy::MPITagMaxValue, MPI_COMM_NULL, 0, 0),
+      TargetEvent(TargetEvent) {}
+
+bool SyncEventTy::runOrigin() { return true; }
+
+bool SyncEventTy::runDestination() {
+  EVENT_BEGIN();
+
+  while (!TargetEvent->isDone()) {
+    EVENT_PAUSE();
+  }
+
+  EVENT_END();
+}
+
+// Exit Event implementation
+// =============================================================================
+ExitEventTy::ExitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                         int DestRank)
+    : BaseEventTy(EventLocationTy::ORIG, EventTypeTy::EXIT, MPITag, TargetComm,
+                  OrigRank, DestRank) {}
+
+bool ExitEventTy::runOrigin() {
+  assert(EventLocation == EventLocationTy::ORIG);
+
+  EVENT_BEGIN();
+
+  // Event completion notification
+  MPI_Irecv(nullptr, 0, MPI_BYTE, DestRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+ExitEventTy::ExitEventTy(int MPITag, MPI_Comm TargetComm, int OrigRank,
+                         int DestRank,
+                         std::atomic<EventSystemStateTy> *EventSystemState)
+    : BaseEventTy(EventLocationTy::DEST, EventTypeTy::EXIT, MPITag, TargetComm,
+                  OrigRank, DestRank),
+      EventSystemState(EventSystemState) {}
+
+bool ExitEventTy::runDestination() {
+  assert(EventLocation == EventLocationTy::DEST);
+
+  EventSystemStateTy OldState;
+
+  EVENT_BEGIN();
+
+  OldState = EventSystemState->exchange(EventSystemStateTy::EXITED);
+  assertm(OldState != EventSystemStateTy::EXITED,
+          "Exit event received multiple times");
+
+  // Event completion notification
+  MPI_Isend(nullptr, 0, MPI_BYTE, OrigRank, MPITag, TargetComm,
+            getNextRequest());
+
+  EVENT_END();
+}
+
+// Event Queue implementation
+// =============================================================================
+EventQueue::EventQueue() : Queue(), QueueMtx(), CanPopCv() {}
+
+size_t EventQueue::size() {
+  std::lock_guard<std::mutex> Lock(QueueMtx);
+  return Queue.size();
+}
+
+void EventQueue::push(EventPtr &Event) {
+  {
+    std::unique_lock<std::mutex> Lock(QueueMtx);
+    Queue.push(Event);
+  }
+
+  // Notifies a thread possibly blocked by an empty queue.
+  CanPopCv.notify_one();
+}
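+
+// Pops an event from the queue, blocking for at most EVENT_POLLING_RATE
+// microseconds while the queue is empty. Returning nullptr on timeout lets
+// the event handler threads periodically re-check their stop condition
+// instead of blocking forever on an empty queue.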
+EventPtr EventQueue::pop() {
+  EventPtr TargetEvent = nullptr;
+
+  {
+    std::unique_lock<std::mutex> Lock(QueueMtx);
+
+    // Waits for at least one item to be pushed.
+    while (Queue.empty()) {
+      const bool HasNewEvent = CanPopCv.wait_for(
+          Lock, std::chrono::microseconds(config::EVENT_POLLING_RATE),
+          [&] { return !Queue.empty(); });
+
+      if (!HasNewEvent) {
+        return nullptr;
+      }
+    }
+
+    assertm(!Queue.empty(), "Queue was empty on pop operation.");
+
+    TargetEvent = Queue.front();
+    Queue.pop();
+  }
+
+  return TargetEvent;
+}
+
+// Event System implementation
+// =============================================================================
+// Event System statics.
+MPI_Comm EventSystemTy::GateThreadComm = MPI_COMM_NULL;
+int32_t EventSystemTy::MPITagMaxValue = 0;
+
+EventSystemTy::EventSystemTy()
+    : EventSystemState(EventSystemStateTy::CREATED) {
+  // Read environment parameters
+  if (const char *EnvStr = std::getenv("OMPCLUSTER_MPI_FRAGMENT_SIZE")) {
+    config::MPI_FRAGMENT_SIZE = std::stoi(EnvStr);
+    assertm(config::MPI_FRAGMENT_SIZE >= 1,
+            "Maximum MPI buffer size must be at least 1");
+    assertm(config::MPI_FRAGMENT_SIZE < std::numeric_limits<int>::max(),
+            "Maximum MPI buffer size must be less than the largest int "
+            "value (MPI restrictions)");
+  }
+
+  if (const char *EnvStr = std::getenv("OMPCLUSTER_NUM_EXEC_EVENT_HANDLERS")) {
+    config::NUM_EXEC_EVENT_HANDLERS = std::stoi(EnvStr);
+    assertm(config::NUM_EXEC_EVENT_HANDLERS >= 1,
+            "At least one exec event handler should be spawned");
+  }
+
+  if (const char *EnvStr = std::getenv("OMPCLUSTER_NUM_DATA_EVENT_HANDLERS")) {
+    config::NUM_DATA_EVENT_HANDLERS = std::stoi(EnvStr);
+    assertm(config::NUM_DATA_EVENT_HANDLERS >= 1,
+            "At least one data event handler should be spawned");
+  }
+
+  if (const char *EnvStr = std::getenv("OMPCLUSTER_EVENT_POLLING_RATE")) {
+    config::EVENT_POLLING_RATE = std::stoi(EnvStr);
+    assertm(config::EVENT_POLLING_RATE >= 0,
+            "Event system polling rate should not be negative");
+  }
+
+  if (const char *EnvStr = std::getenv("OMPCLUSTER_NUM_EVENT_COMM")) {
+    config::NUM_EVENT_COMM = std::stoi(EnvStr);
+    assertm(config::NUM_EVENT_COMM >= 1,
+            "At least one communicator needs to be spawned");
+  }
+}
+
+EventSystemTy::~EventSystemTy() {
+  if (!IsInitialized)
+    return;
+
+  REPORT("Destructing internal event system before deinitializing it.\n");
+  deinitialize();
+}
+
+bool EventSystemTy::initialize() {
+  if (IsInitialized) {
+    REPORT("Trying to initialize event system twice.\n");
+    return false;
+  }
+
+  if (!createLocalMPIContext())
+    return false;
+
+  IsInitialized = true;
+
+  return true;
+}
+
+bool EventSystemTy::deinitialize() {
+  if (!IsInitialized) {
+    REPORT("Trying to deinitialize event system twice.\n");
+    return false;
+  }
+
+  // Only send exit events from the host side
+  if (isHead() && WorldSize > 1) {
+    const int NumWorkers = WorldSize - 1;
+    llvm::SmallVector<EventPtr> ExitEvents(NumWorkers);
+    for (int WorkerRank = 0; WorkerRank < NumWorkers; WorkerRank++) {
+      ExitEvents[WorkerRank] = createEvent<ExitEventTy>(WorkerRank);
+      ExitEvents[WorkerRank]->progress();
+    }
+
+    bool SuccessfullyExited = true;
+    for (int WorkerRank = 0; WorkerRank < NumWorkers; WorkerRank++) {
+      ExitEvents[WorkerRank]->wait();
+      SuccessfullyExited &=
+          ExitEvents[WorkerRank]->getEventState() == EventStateTy::FINISHED;
+    }
+
+    if (!SuccessfullyExited) {
+      REPORT("Failed to stop worker processes.\n");
+      return false;
+    }
+  }
+
+  if (!destroyLocalMPIContext())
+    return false;
+
+  IsInitialized = false;
+
+  return true;
+}
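+
+// Event handler loop run by each handler thread: pops events from the given
+// queue and advances them until the event system exits and the queue drains.
+// Unfinished events are pushed back to the queue, so a handler effectively
+// round-robins over all in-flight events assigned to it.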
+void EventSystemTy::runEventHandler(EventQueue &Queue) {
+  while (EventSystemState == EventSystemStateTy::RUNNING || Queue.size() > 0) {
+    EventPtr Event = Queue.pop();
+
+    // Re-checks the stop condition when no event was found.
+    if (Event == nullptr) {
+      continue;
+    }
+
+    Event->progress();
+
+    if (!Event->isDone()) {
+      Queue.push(Event);
+    }
+  }
+}
+
+void EventSystemTy::runGateThread(__tgt_target_table *TargetTable) {
+  // Updates the event system state.
+  EventSystemState = EventSystemStateTy::RUNNING;
+
+  // Spawns the event handlers.
+  llvm::SmallVector<std::thread> EventHandlers;
+  EventHandlers.resize(config::NUM_EXEC_EVENT_HANDLERS +
+                       config::NUM_DATA_EVENT_HANDLERS);
+  for (int Idx = 0; Idx < EventHandlers.size(); Idx++) {
+    EventHandlers[Idx] = std::thread(
+        &EventSystemTy::runEventHandler, this,
+        std::ref(Idx < config::NUM_EXEC_EVENT_HANDLERS ? ExecEventQueue
+                                                       : DataEventQueue));
+  }
+
+  // Executes the gate thread logic
+  while (EventSystemState == EventSystemStateTy::RUNNING) {
+    // Checks for new incoming event requests.
+    MPI_Message EventReqMsg;
+    MPI_Status EventStatus;
+    int HasReceived = false;
+    MPI_Improbe(MPI_ANY_SOURCE,
+                static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+                GateThreadComm, &HasReceived, &EventReqMsg, MPI_STATUS_IGNORE);
+
+    // If none was received, wait for `EVENT_POLLING_RATE`us for the next
+    // check.
+    if (!HasReceived) {
+      std::this_thread::sleep_for(
+          std::chrono::microseconds(config::EVENT_POLLING_RATE));
+      continue;
+    }
+
+    // Acquires the event information from the received request, which is:
+    // - Event type
+    // - Event tag
+    // - Target comm (derived from the event tag)
+    // - Event source rank (taken from the MPI status)
+    uint32_t EventInfo[2];
+    MPI_Mrecv(EventInfo, 2, MPI_UINT32_T, &EventReqMsg, &EventStatus);
+    const auto NewEventType = static_cast<EventTypeTy>(EventInfo[0]);
+    const uint32_t NewEventTag = EventInfo[1];
+    auto &NewEventComm = getNewEventComm(NewEventTag);
+    const int OrigRank = EventStatus.MPI_SOURCE;
+
+    // Creates a new receive event of 'NewEventType' type.
+    EventPtr NewEvent;
+    switch (NewEventType) {
+    case EventTypeTy::ALLOC:
+      NewEvent = std::make_shared<AllocEventTy>(NewEventTag, NewEventComm,
+                                                OrigRank, LocalRank);
+      break;
+    case EventTypeTy::DELETE:
+      NewEvent = std::make_shared<DeleteEventTy>(NewEventTag, NewEventComm,
+                                                 OrigRank, LocalRank);
+      break;
+    case EventTypeTy::RETRIEVE:
+      NewEvent = std::make_shared<RetrieveEventTy>(NewEventTag, NewEventComm,
+                                                   OrigRank, LocalRank);
+      break;
+    case EventTypeTy::SUBMIT:
+      NewEvent = std::make_shared<SubmitEventTy>(NewEventTag, NewEventComm,
+                                                 OrigRank, LocalRank);
+      break;
+    case EventTypeTy::EXCHANGE:
+      NewEvent = std::make_shared<ExchangeEventTy>(NewEventTag, NewEventComm,
+                                                   OrigRank, LocalRank);
+      break;
+    case EventTypeTy::EXECUTE:
+      NewEvent = std::make_shared<ExecuteEventTy>(
+          NewEventTag, NewEventComm, OrigRank, LocalRank, TargetTable);
+      break;
+    case EventTypeTy::EXIT:
+      NewEvent = std::make_shared<ExitEventTy>(
+          NewEventTag, NewEventComm, OrigRank, LocalRank, &EventSystemState);
+      break;
+    case EventTypeTy::SYNC:
+      assertm(false, "Trying to create a local event on a remote node");
+    }
+
+    assertm(NewEvent != nullptr, "Created event must not be a nullptr");
+    assertm(NewEvent->EventLocation == EventLocationTy::DEST,
+            "Gate thread must only create destination-side events");
+
+    if (NewEventType == EventTypeTy::EXECUTE) {
+      ExecEventQueue.push(NewEvent);
+    } else {
+      DataEventQueue.push(NewEvent);
+    }
+  }
+
+  assertm(EventSystemState == EventSystemStateTy::EXITED,
+          "Event State should be EXITED after receiving an Exit event");
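+
+  // The handlers may still be draining their queues at this point:
+  // runEventHandler only returns once the system has exited and its queue is
+  // empty, so joining below also waits for all pending events to complete.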
+
+  // Waits for the Event Handler threads.
+  for (auto &EventHandler : EventHandlers) {
+    if (EventHandler.joinable()) {
+      EventHandler.join();
+    } else {
+      assertm(false, "Event Handler threads not joinable at the end of gate "
+                     "thread logic.");
+    }
+  }
+}
+
+// Creates a new event tag of at least 'FIRST_EVENT' value.
+// Tag values smaller than 'FIRST_EVENT' are reserved for control
+// communication between the event systems of different MPI processes.
+int EventSystemTy::createNewEventTag() {
+  uint32_t Tag = 0;
+
+  do {
+    Tag = EventCounter.fetch_add(1) % MPITagMaxValue;
+  } while (Tag < static_cast<uint32_t>(ControlTagsTy::FIRST_EVENT));
+
+  return Tag;
+}
+
+MPI_Comm &EventSystemTy::getNewEventComm(int MPITag) {
+  // Retrieve a comm using a round-robin strategy around the event's MPI tag.
+  return EventCommPool[MPITag % EventCommPool.size()];
+}
+
+static const char *threadLevelToString(int ThreadLevel) {
+  switch (ThreadLevel) {
+  case MPI_THREAD_SINGLE:
+    return "MPI_THREAD_SINGLE";
+  case MPI_THREAD_SERIALIZED:
+    return "MPI_THREAD_SERIALIZED";
+  case MPI_THREAD_FUNNELED:
+    return "MPI_THREAD_FUNNELED";
+  case MPI_THREAD_MULTIPLE:
+    return "MPI_THREAD_MULTIPLE";
+  default:
+    return "unknown";
+  }
+}
+
+bool EventSystemTy::createLocalMPIContext() {
+  int MPIError = MPI_SUCCESS;
+
+  // Initialize the MPI context.
+  int IsMPIInitialized = 0;
+  int ThreadLevel = MPI_THREAD_SINGLE;
+
+  MPI_Initialized(&IsMPIInitialized);
+
+  if (IsMPIInitialized)
+    MPI_Query_thread(&ThreadLevel);
+  else
+    MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ThreadLevel);
+
+  CHECK(ThreadLevel == MPI_THREAD_MULTIPLE,
+        "MPI plugin requires an MPI implementation with %s thread level. "
+        "Implementation only supports up to %s.\n",
+        threadLevelToString(MPI_THREAD_MULTIPLE),
+        threadLevelToString(ThreadLevel));
+
+  // Create gate thread comm.
+  MPIError = MPI_Comm_dup(MPI_COMM_WORLD, &GateThreadComm);
+  CHECK(MPIError == MPI_SUCCESS,
+        "Failed to create gate thread MPI comm with error %d\n", MPIError);
+
+  // Create event comm pool.
+  EventCommPool.resize(config::NUM_EVENT_COMM, MPI_COMM_NULL);
+  for (auto &Comm : EventCommPool) {
+    MPIError = MPI_Comm_dup(MPI_COMM_WORLD, &Comm);
+    CHECK(MPIError == MPI_SUCCESS,
+          "Failed to create MPI comm pool with error %d\n", MPIError);
+  }
+
+  // Get local MPI process description.
+  MPIError = MPI_Comm_rank(GateThreadComm, &LocalRank);
+  CHECK(MPIError == MPI_SUCCESS,
+        "Failed to acquire the local MPI rank with error %d\n", MPIError);
+
+  MPIError = MPI_Comm_size(GateThreadComm, &WorldSize);
+  CHECK(MPIError == MPI_SUCCESS,
+        "Failed to acquire the world size with error %d\n", MPIError);
+
+  // Get max value for MPI tags.
+  MPI_Aint *Value = nullptr;
+  int Flag = 0;
+  MPIError = MPI_Comm_get_attr(GateThreadComm, MPI_TAG_UB, &Value, &Flag);
+  CHECK(Flag == 1 && MPIError == MPI_SUCCESS,
+        "Failed to acquire the MPI max tag value with error %d\n", MPIError);
+  MPITagMaxValue = *Value;
+
+  return true;
+}
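+
+// Rank 0 acts as the head (host) process while every other rank behaves as an
+// offloading device, so a job launched with N + 1 ranks exposes N devices.
+// For example (illustrative): `mpirun -np 5 ./omp-program` runs one head
+// process and four workers.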
+
+bool EventSystemTy::destroyLocalMPIContext() {
+  int MPIError = MPI_SUCCESS;
+
+  // Note: We don't need to assert here since the application part of the
+  // program has already finished.
+  // Free gate thread comm.
+  MPIError = MPI_Comm_free(&GateThreadComm);
+  CHECK(MPIError == MPI_SUCCESS,
+        "Failed to destroy the gate thread MPI comm with error %d\n", MPIError);
+
+  // Free event comm pool.
+  for (auto &Comm : EventCommPool) {
+    MPIError = MPI_Comm_free(&Comm);
+    CHECK(MPIError == MPI_SUCCESS,
+          "Failed to destroy the event MPI comm with error %d\n", MPIError);
+  }
+  EventCommPool.resize(0);
+
+  // Finalize the global MPI session.
+  int IsFinalized = false;
+  MPIError = MPI_Finalized(&IsFinalized);
+
+  if (IsFinalized) {
+    DP("MPI was already finalized externally.\n");
+  } else {
+    MPIError = MPI_Finalize();
+    CHECK(MPIError == MPI_SUCCESS, "Failed to finalize MPI with error: %d\n",
+          MPIError);
+  }
+
+  return true;
+}
+
+int EventSystemTy::getNumWorkers() const { return WorldSize - 1; }
+
+int EventSystemTy::isHead() const { return LocalRank == 0; }
diff --git a/openmp/libomptarget/plugins/mpi/src/MPIManager.cpp b/openmp/libomptarget/plugins/mpi/src/MPIManager.cpp
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/src/MPIManager.cpp
@@ -0,0 +1,620 @@
+//===--- RTLs/mpi/src/MPIManager.cpp - Target RTLs Implementation - C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RTL for MPI machine
+//
+//===----------------------------------------------------------------------===//
+
+#include "MPIManager.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// ELF utilities definitions.
+// =============================================================================
+#ifndef TARGET_ELF_ID
+#define TARGET_ELF_ID 0
+#endif
+#include "elf_common.h"
+
+// Memory Allocator.
+// =============================================================================
+constexpr const char *toString(const TargetAllocTy &Kind) {
+  switch (Kind) {
+  case TargetAllocTy::TARGET_ALLOC_DEFAULT:
+    return "TARGET_ALLOC_DEFAULT";
+  case TargetAllocTy::TARGET_ALLOC_DEVICE:
+    return "TARGET_ALLOC_DEVICE";
+  case TargetAllocTy::TARGET_ALLOC_HOST:
+    return "TARGET_ALLOC_HOST";
+  case TargetAllocTy::TARGET_ALLOC_SHARED:
+    return "TARGET_ALLOC_SHARED";
+  }
+}
+
+void *MPIManagerTy::MPIDeviceAllocatorTy::allocate(size_t Size, void *HostPtr,
+                                                   TargetAllocTy Kind) {
+  if (Kind != TargetAllocTy::TARGET_ALLOC_DEFAULT &&
+      Kind != TargetAllocTy::TARGET_ALLOC_DEVICE) {
+    REPORT("Invalid allocation kind %s. MPI plugin only supports "
+           "TARGET_ALLOC_DEFAULT and TARGET_ALLOC_DEVICE.\n",
+           toString(Kind));
+    return nullptr;
+  }
+
+  if (Size == 0)
+    return nullptr;
+
+  void *DevicePtr = nullptr;
+  auto Event =
+      EventSystem.createEvent<AllocEventTy>(DeviceId, Size, &DevicePtr);
+  Event->wait();
+
+  if (Event->getEventState() == EventStateTy::FAILED)
+    return nullptr;
+
+  return DevicePtr;
+}
+
+int MPIManagerTy::MPIDeviceAllocatorTy::free(void *TargetPtr,
+                                             TargetAllocTy Kind) {
+  if (Kind != TargetAllocTy::TARGET_ALLOC_DEFAULT &&
+      Kind != TargetAllocTy::TARGET_ALLOC_DEVICE) {
+    REPORT("Invalid allocation kind %s. MPI plugin only supports "
+           "TARGET_ALLOC_DEFAULT and TARGET_ALLOC_DEVICE.\n",
+           toString(Kind));
+    return OFFLOAD_FAIL;
+  }
+
+  auto Event = EventSystem.createEvent<DeleteEventTy>(DeviceId, TargetPtr);
+  Event->wait();
+
+  if (Event->getEventState() == EventStateTy::FAILED)
+    return OFFLOAD_FAIL;
+
+  return OFFLOAD_SUCCESS;
+}
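+
+// Both allocator hooks are synchronous: they create the corresponding remote
+// event, wait for it, and translate the final event state into the result.
+// For example (illustrative):
+//   void *Ptr = Allocator.allocate(Size, HostPtr, TARGET_ALLOC_DEFAULT);
+//   if (Ptr)
+//     Allocator.free(Ptr, TARGET_ALLOC_DEFAULT);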
+
+// MPI Manager.
+// =============================================================================
+// De/initialization functions.
+// =============================================================================
+MPIManagerTy::~MPIManagerTy() {
+  if (!IsInitialized)
+    return;
+
+  REPORT("Destructing internal plugin manager before deinitializing it.\n");
+  deinitialize();
+}
+
+bool MPIManagerTy::initialize() {
+  if (IsInitialized) {
+    REPORT("Trying to initialize MPI plugin manager twice.\n");
+    return false;
+  }
+
+  if (!EventSystem.initialize())
+    return false;
+
+  const int NumWorkers = EventSystem.getNumWorkers();
+
+  // Set function entries
+  FuncGblEntries.resize(EventSystem.isHead() ? NumWorkers : 1);
+
+  // Set device allocators and the memory manager.
+  for (int DeviceId = 0; DeviceId < NumWorkers; ++DeviceId)
+    ProcessAllocators.emplace_back(DeviceId, EventSystem);
+
+  auto [ManagerThreshold, ManagerEnabled] =
+      MemoryManagerTy::getSizeThresholdFromEnv();
+  UseMemoryManager = ManagerEnabled;
+
+  if (UseMemoryManager) {
+    for (int I = 0; I < NumWorkers; ++I)
+      MemoryManagers.emplace_back(std::make_unique<MemoryManagerTy>(
+          ProcessAllocators[I], ManagerThreshold));
+  }
+
+  IsInitialized = true;
+
+  return true;
+}
+
+bool MPIManagerTy::deinitialize() {
+  if (!IsInitialized) {
+    REPORT("Trying to deinitialize MPI plugin manager twice.\n");
+    return false;
+  }
+
+  // Remove the temporary files backing the loaded dynamic libraries.
+  for (auto &Lib : DynLibs) {
+    if (Lib.DynLib->isValid())
+      remove(Lib.FileName.c_str());
+  }
+
+  // Destruct memory managers.
+  for (auto &M : MemoryManagers)
+    M.reset();
+
+  if (!EventSystem.deinitialize())
+    return false;
+
+  IsInitialized = false;
+
+  return true;
+}
+
+void MPIManagerTy::createOffloadTable(
+    int32_t DeviceId, SmallVector<__tgt_offload_entry> &&Entries) {
+  assert(DeviceId < (int32_t)FuncGblEntries.size() && "Unexpected device id!");
+  FuncGblEntries[DeviceId].emplace_back();
+  FuncOrGblEntryTy &E = FuncGblEntries[DeviceId].back();
+
+  E.Entries = std::move(Entries);
+  E.Table.EntriesBegin = E.Entries.begin();
+  E.Table.EntriesEnd = E.Entries.end();
+}
+
+bool MPIManagerTy::findOffloadEntry(int32_t DeviceId, void *Addr) {
+  assert(DeviceId < (int32_t)FuncGblEntries.size() && "Unexpected device id!");
+  FuncOrGblEntryTy &E = FuncGblEntries[DeviceId].back();
+
+  for (__tgt_offload_entry *I = E.Table.EntriesBegin, *End = E.Table.EntriesEnd;
+       I < End; ++I) {
+    if (I->addr == Addr)
+      return true;
+  }
+
+  return false;
+}
+
+__tgt_target_table *MPIManagerTy::getOffloadEntriesTable(int32_t DeviceId) {
+  assert(DeviceId < (int32_t)FuncGblEntries.size() && "Unexpected device id!");
+  FuncOrGblEntryTy &E = FuncGblEntries[DeviceId].back();
+
+  return &E.Table;
+}
+
+__tgt_target_table *MPIManagerTy::getOffloadEntriesTableOnWorker() {
+  return getOffloadEntriesTable(0);
+}
+
+void MPIManagerTy::registerLibOnWorker(__tgt_bin_desc *Desc) {
+  // Register the images with the RTLs that understand them, if any.
+  for (int32_t I = 0; I < Desc->NumDeviceImages; ++I) {
+    __tgt_device_image *Img = &Desc->DeviceImages[I];
+
+    if (!isValidBinary(Img)) {
+      DP("Image " DPxMOD " is NOT compatible with this MPI device!\n",
+         DPxPTR(Img->ImageStart));
+      continue;
+    }
+
+    DP("Image " DPxMOD " is compatible with this MPI device!\n",
+       DPxPTR(Img->ImageStart));
+
+    loadBinaryOnWorker(Img);
+  }
+
+  DP("Done registering entries!\n");
+}
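+
+// Worker processes never execute the application's host main: runDeviceMain()
+// calls registerLibOnWorker() to load every compatible image and then enters
+// the gate thread loop, exiting the process once the loop returns.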
+
+int32_t MPIManagerTy::isValidBinary(__tgt_device_image *Image) const {
+// If we don't have a valid ELF ID we can just fail.
+#if TARGET_ELF_ID < 1
+  return 0;
+#else
+  return elf_check_machine(Image, TARGET_ELF_ID);
+#endif
+}
+
+__tgt_target_table *MPIManagerTy::loadBinary(const int DeviceId,
+                                             const __tgt_device_image *Image) {
+  DP("Dev %d: load binary from " DPxMOD " image\n", DeviceId,
+     DPxPTR(Image->ImageStart));
+
+  if (!checkValidDeviceId(DeviceId)) {
+    REPORT("Trying to load a binary into an invalid device ID %d\n", DeviceId);
+    return nullptr;
+  }
+
+  size_t ImageSize = (size_t)Image->ImageEnd - (size_t)Image->ImageStart;
+
+  // Load the dynamic library and get the entry points. We use the dl library
+  // to do the loading of the library, but we could do it directly to avoid the
+  // dump to the temporary file.
+  //
+  // 1) Create tmp file with the library contents.
+  // 2) Use dlopen to load the file and dlsym to retrieve the symbols.
+  char TmpName[] = "/tmp/tmpfile_XXXXXX";
+  int TmpFd = mkstemp(TmpName);
+
+  if (TmpFd == -1) {
+    REPORT("Failed to load binary. Failed to create temporary file %s.\n",
+           TmpName);
+    return nullptr;
+  }
+
+  FILE *TmpFile = fdopen(TmpFd, "wb");
+
+  if (!TmpFile) {
+    REPORT("Failed to load binary. Failed to open new temporary file %s.\n",
+           TmpName);
+    return nullptr;
+  }
+
+  fwrite(Image->ImageStart, ImageSize, 1, TmpFile);
+  fclose(TmpFile);
+
+  std::string ErrMsg;
+  auto DynLib = std::make_unique<llvm::sys::DynamicLibrary>(
+      llvm::sys::DynamicLibrary::getPermanentLibrary(TmpName, &ErrMsg));
+  DynLibTy Lib = {TmpName, std::move(DynLib)};
+
+  if (!Lib.DynLib->isValid()) {
+    REPORT("Failed to load binary. Loading error: %s.\n", ErrMsg.c_str());
+    return nullptr;
+  }
+
+  __tgt_offload_entry *HostBegin = Image->EntriesBegin;
+  __tgt_offload_entry *HostEnd = Image->EntriesEnd;
+
+  // Create a new offloading entry list using the device symbol address.
+  SmallVector<__tgt_offload_entry> Entries;
+  for (__tgt_offload_entry *E = HostBegin; E != HostEnd; ++E) {
+    if (!E->addr) {
+      REPORT("Failed to load binary. Found null entry.\n");
+      return nullptr;
+    }
+
+    __tgt_offload_entry Entry = *E;
+
+    void *DevAddr = Lib.DynLib->getAddressOfSymbol(E->name);
+    Entry.addr = DevAddr;
+
+    DP("Entry point " DPxMOD " maps to global %s (" DPxMOD ")\n",
+       DPxPTR(E - HostBegin), E->name, DPxPTR(DevAddr));
+
+    Entries.emplace_back(Entry);
+  }
+
+  createOffloadTable(DeviceId, std::move(Entries));
+  DynLibs.emplace_back(std::move(Lib));
+
+  return getOffloadEntriesTable(DeviceId);
+}
+
+__tgt_target_table *
+MPIManagerTy::loadBinaryOnWorker(const __tgt_device_image *Image) {
+  return loadBinary(0, Image);
+}
+
+bool MPIManagerTy::isValidDeviceId(const int DeviceId) const {
+  return DeviceId >= 0 && DeviceId < EventSystem.getNumWorkers();
+}
+
+int MPIManagerTy::getNumOfDevices() const {
+  return EventSystem.getNumWorkers();
+}
+
+bool MPIManagerTy::checkValidDeviceId(const int DeviceId) const {
+  if (!isValidDeviceId(DeviceId)) {
+    REPORT("Received device id %d out of range of valid ids [%d, %d]\n",
+           DeviceId, 0, EventSystem.getNumWorkers() - 1);
+    return false;
+  }
+
+  return true;
+}
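+
+// Note that device IDs are zero-based on the host side, while the event
+// system apparently maps device D to MPI rank D + 1 (rank 0 being the head),
+// as seen in the `DstId + 1` adjustment done for exchange events below.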
+
+bool MPIManagerTy::checkValidAsyncInfo(
+    const __tgt_async_info *AsyncInfo) const {
+  if (AsyncInfo == nullptr) {
+    REPORT("Plugin call failed. Received null AsyncInfo\n");
+    return false;
+  }
+
+  return true;
+}
+
+int32_t MPIManagerTy::checkCreatedEvent(const EventPtr &Event) const {
+  if (Event->getEventState() != EventStateTy::CREATED)
+    return OFFLOAD_FAIL;
+
+  return OFFLOAD_SUCCESS;
+}
+
+bool MPIManagerTy::checkRecordedEventPtr(const void *Event) const {
+  if (!Event) {
+    REPORT("Received an invalid recorded event pointer\n");
+    return false;
+  }
+
+  return true;
+}
+
+void *MPIManagerTy::dataAlloc(int32_t DeviceId, int64_t Size, void *HostPtr,
+                              TargetAllocTy Kind) {
+  if (!checkValidDeviceId(DeviceId))
+    return nullptr;
+
+  if (UseMemoryManager)
+    return MemoryManagers[DeviceId]->allocate(Size, HostPtr);
+
+  return ProcessAllocators[DeviceId].allocate(Size, HostPtr, Kind);
+}
+
+int32_t MPIManagerTy::dataDelete(int32_t DeviceId, void *TargetPtr,
+                                 TargetAllocTy Kind) {
+  if (!checkValidDeviceId(DeviceId))
+    return OFFLOAD_FAIL;
+
+  if (UseMemoryManager)
+    return MemoryManagers[DeviceId]->free(TargetPtr);
+
+  return ProcessAllocators[DeviceId].free(TargetPtr, Kind);
+}
+
+int32_t MPIManagerTy::dataSubmit(int32_t DeviceId, void *TargetPtr,
+                                 void *HostPtr, int64_t Size,
+                                 __tgt_async_info *AsyncInfo) {
+  if (!checkValidDeviceId(DeviceId))
+    return OFFLOAD_FAIL;
+
+  if (!checkValidAsyncInfo(AsyncInfo))
+    return OFFLOAD_FAIL;
+
+  auto Event = EventSystem.createEvent<SubmitEventTy>(DeviceId, TargetPtr,
+                                                      HostPtr, Size);
+  pushNewEvent(Event, AsyncInfo);
+
+  return checkCreatedEvent(Event);
+}
+
+int32_t MPIManagerTy::dataRetrieve(int32_t DeviceId, void *HostPtr,
+                                   void *TargetPtr, int64_t Size,
+                                   __tgt_async_info *AsyncInfo) {
+  if (!checkValidDeviceId(DeviceId))
+    return OFFLOAD_FAIL;
+
+  if (!checkValidAsyncInfo(AsyncInfo))
+    return OFFLOAD_FAIL;
+
+  auto Event = EventSystem.createEvent<RetrieveEventTy>(DeviceId, HostPtr,
+                                                        TargetPtr, Size);
+  pushNewEvent(Event, AsyncInfo);
+
+  return checkCreatedEvent(Event);
+}
+
+int32_t MPIManagerTy::dataExchange(int32_t SrcID, void *SrcPtr, int32_t DstId,
+                                   void *DstPtr, int64_t Size,
+                                   __tgt_async_info *AsyncInfo) {
+  if (!checkValidDeviceId(SrcID) || !checkValidDeviceId(DstId))
+    return OFFLOAD_FAIL;
+
+  if (!checkValidAsyncInfo(AsyncInfo))
+    return OFFLOAD_FAIL;
+
+  auto Event = EventSystem.createEvent<ExchangeEventTy>(SrcID, DstId + 1,
+                                                        SrcPtr, DstPtr, Size);
+  pushNewEvent(Event, AsyncInfo);
+
+  return checkCreatedEvent(Event);
+}
+
+int32_t MPIManagerTy::runTargetRegion(int32_t DeviceId, void *Entry,
+                                      void **Args, ptrdiff_t *Offsets,
+                                      int32_t NumArgs,
+                                      __tgt_async_info *AsyncInfo) {
+  if (!checkValidDeviceId(DeviceId))
+    return OFFLOAD_FAIL;
+
+  if (!checkValidAsyncInfo(AsyncInfo))
+    return OFFLOAD_FAIL;
+
+  // Prepare all args based on their offsets.
+  SmallVector<void *> ArgPtrs(NumArgs);
+
+  for (int I = 0; I < NumArgs; ++I) {
+    ArgPtrs[I] = (void *)((intptr_t)Args[I] + Offsets[I]);
+  }
+
+  // Get the translation table (which contains all the good info).
+  __tgt_target_table *TargetTable = getOffloadEntriesTable(DeviceId);
+  // Iterate over all the host table entries to see if we can locate the
+  // host_ptr.
+  __tgt_offload_entry *Begin = TargetTable->EntriesBegin;
+  __tgt_offload_entry *End = TargetTable->EntriesEnd;
+  __tgt_offload_entry *Curr = Begin;
+
+  uint32_t EntryIdx = -1;
+
+  for (uint32_t I = 0; Curr < End; ++Curr, ++I) {
+    if (Curr->addr != Entry)
+      continue;
+    // We got a match: remember the entry index so the worker can locate the
+    // same entry in its own table.
+    DP("[MPI host] Running kernel called %s...\n", Curr->name);
+    EntryIdx = I;
+    break;
+  }
+
+  auto Event = EventSystem.createEvent<ExecuteEventTy>(
+      DeviceId, NumArgs, ArgPtrs.data(), EntryIdx);
+  pushNewEvent(Event, AsyncInfo);
+
+  return checkCreatedEvent(Event);
+}
+
+MPIManagerTy::EventQueue *
+MPIManagerTy::getEventQueue(__tgt_async_info *AsyncInfo) {
+  if (!checkValidAsyncInfo(AsyncInfo))
+    return nullptr;
+
+  // Lazily create the queue on first use; later calls must reuse it instead
+  // of allocating a new one.
+  if (AsyncInfo->Queue == nullptr)
+    AsyncInfo->Queue = reinterpret_cast<void *>(new EventQueue);
+
+  return reinterpret_cast<EventQueue *>(AsyncInfo->Queue);
+}
+
+void MPIManagerTy::pushNewEvent(const EventPtr &Event,
+                                __tgt_async_info *AsyncInfo) {
+  auto *Queue = getEventQueue(AsyncInfo);
+  Queue->push_back(Event);
+}
+
+int32_t MPIManagerTy::synchronize(int32_t DeviceId,
+                                  __tgt_async_info *AsyncInfo) {
+  if (AsyncInfo == nullptr || AsyncInfo->Queue == nullptr)
+    return OFFLOAD_SUCCESS;
+
+  // Acquire the async context.
+  EventQueue *Queue = getEventQueue(AsyncInfo);
+
+  int Result = OFFLOAD_SUCCESS;
+  for (auto &Event : *Queue) {
+    Event->wait();
+
+    // Check if the event failed
+    if (Event->getEventState() == EventStateTy::FAILED) {
+      REPORT("Event %s has failed during synchronization.\n",
+             toString(Event->EventType));
+      Result = OFFLOAD_FAIL;
+      break;
+    }
+  }
+
+  // Delete the current async_info context. Further use of the same async_info
+  // object must create a new context.
+  delete Queue;
+  AsyncInfo->Queue = nullptr;
+
+  return Result;
+}
+
+bool MPIManagerTy::isInsideDevice() { return !EventSystem.isHead(); }
+
+void MPIManagerTy::runDeviceMain(__tgt_bin_desc *Desc) {
+  // Check whether it is a device or not and if so run its initialization
+  if (EventSystem.isHead())
+    return;
+
+  registerLibOnWorker(Desc);
+
+  EventSystem.runGateThread(getOffloadEntriesTableOnWorker());
+
+  std::exit(EXIT_SUCCESS);
+}
+
+// Synchronization event management
+// ===========================================================================
+int32_t MPIManagerTy::createEvent(int32_t ID, void **Event) {
+  if (!checkRecordedEventPtr(Event))
+    return OFFLOAD_FAIL;
+
+  auto RecordedEvent = new EventPtr;
+  if (RecordedEvent == nullptr) {
+    REPORT("Could not allocate a new synchronization event\n");
+    return OFFLOAD_FAIL;
+  }
+
+  *Event = reinterpret_cast<void *>(RecordedEvent);
+
+  return OFFLOAD_SUCCESS;
+}
+
+int32_t MPIManagerTy::destroyEvent(int32_t ID, void *Event) {
+  if (!checkRecordedEventPtr(Event))
+    return OFFLOAD_FAIL;
+
+  delete reinterpret_cast<EventPtr *>(Event);
+
+  return OFFLOAD_SUCCESS;
+}
+
+int32_t MPIManagerTy::recordEvent(int32_t ID, void *Event,
+                                  __tgt_async_info *AsyncInfo) {
+  if (!checkRecordedEventPtr(Event))
+    return OFFLOAD_FAIL;
+
+  if (AsyncInfo == nullptr || AsyncInfo->Queue == nullptr) {
+    REPORT("Received an invalid async queue on recordEvent\n");
+    return OFFLOAD_FAIL;
+  }
+
+  EventQueue *Queue = getEventQueue(AsyncInfo);
+  if (Queue->empty()) {
+    DP("Tried to record an event for an empty event queue\n");
+    return OFFLOAD_SUCCESS;
+  }
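+
+  // Recording is a cheap snapshot: the handle simply aliases the last event
+  // currently enqueued, mirroring the record/wait semantics of stream-based
+  // plugins such as CUDA.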
+  // Copy the last event in the queue to the event handle.
+  auto &RecordedEvent = *reinterpret_cast<EventPtr *>(Event);
+  RecordedEvent = Queue->back();
+
+  return OFFLOAD_SUCCESS;
+}
+
+int32_t MPIManagerTy::waitEvent(int32_t ID, void *Event,
+                                __tgt_async_info *AsyncInfo) {
+  if (!checkRecordedEventPtr(Event))
+    return OFFLOAD_FAIL;
+
+  if (AsyncInfo == nullptr) {
+    REPORT("Received an invalid async info on waitEvent\n");
+    return OFFLOAD_FAIL;
+  }
+
+  auto &RecordedEvent = *reinterpret_cast<EventPtr *>(Event);
+  if (!RecordedEvent) {
+    DP("Tried to wait an empty event\n");
+    return OFFLOAD_SUCCESS;
+  }
+
+  // Create a wait event that waits for `Event` to be completed and add it to
+  // the event queue. This ensures that the whole event queue where `Event`
+  // originated is completed up to the `Event` itself. Directly waiting on
+  // `Event` would execute it instead of waiting for its predecessors in its
+  // original event queue.
+  EventQueue *Queue = getEventQueue(AsyncInfo);
+  Queue->push_back(std::make_shared<SyncEventTy>(RecordedEvent));
+  return OFFLOAD_SUCCESS;
+}
+
+int32_t MPIManagerTy::syncEvent(int32_t ID, void *Event) {
+  if (!checkRecordedEventPtr(Event))
+    return OFFLOAD_FAIL;
+
+  auto &RecordedEvent = *reinterpret_cast<EventPtr *>(Event);
+  if (!RecordedEvent) {
+    DP("Tried to synchronize an empty event\n");
+    return OFFLOAD_SUCCESS;
+  }
+
+  // Create a wait event that waits for `Event` to be completed and execute it.
+  // This ensures that the whole event queue where `Event` originated is
+  // completed up to the `Event` itself. Directly waiting on `Event` would
+  // execute it instead of waiting for its predecessors in its original event
+  // queue.
+  auto WaitEvent = std::make_shared<SyncEventTy>(RecordedEvent);
+  WaitEvent->wait();
+
+  return OFFLOAD_SUCCESS;
+}
diff --git a/openmp/libomptarget/plugins/mpi/src/rtl.cpp b/openmp/libomptarget/plugins/mpi/src/rtl.cpp
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/plugins/mpi/src/rtl.cpp
@@ -0,0 +1,249 @@
+//===------RTLs/mpi/src/rtl.cpp - Target RTLs Implementation - C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RTL for MPI applications
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstdint>
+#include <functional>
+#include <type_traits>
+
+#include "Common.h"
+#include "MPIManager.h"
+
+#include "omptargetplugin.h"
+
+/// Global context that stores device information for the entire binary.
+static MPIManagerTy *MPIManager = nullptr;
+
+/// Helper functions
+template <typename FuncTy, typename... ArgsTys>
+static int32_t syncCall(FuncTy Function, int32_t DeviceId, ArgsTys &&...Args) {
+  static_assert(std::is_member_function_pointer_v<FuncTy>,
+                "FuncTy should be a member function pointer.");
+  using IDTy = int32_t;
+  using AsyncInfoPtrTy = __tgt_async_info *;
+  static_assert(std::is_invocable_r_v<int32_t, FuncTy, MPIManagerTy *, IDTy,
+                                      ArgsTys..., AsyncInfoPtrTy>,
+                "Function is not callable with given arguments.");
+
+  __tgt_async_info LocalAsyncInfo;
+  int32_t RetCode =
+      std::invoke(Function, MPIManager, DeviceId,
+                  std::forward<ArgsTys>(Args)..., &LocalAsyncInfo);
+
+  if (RetCode == OFFLOAD_SUCCESS)
+    RetCode = MPIManager->synchronize(DeviceId, &LocalAsyncInfo);
+
+  return RetCode;
+}
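+
+// For example (illustrative), the synchronous entry point
+// __tgt_rtl_data_submit() below reduces to
+//   syncCall(&MPIManagerTy::dataSubmit, DeviceId, TargetPtr, HostPtr, Size);
+// which runs dataSubmit() against a stack-allocated __tgt_async_info and
+// synchronizes on it before returning.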
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// De/initialization functions.
+// =============================================================================
+
+int32_t __tgt_rtl_init_plugin() {
+  MPIManager = new MPIManagerTy();
+
+  if (!MPIManager) {
+    REPORT("Failed to allocate MPI plugin manager.\n");
+    return OFFLOAD_FAIL;
+  }
+
+  if (!MPIManager->initialize()) {
+    REPORT("Failed to initialize the MPI plugin.\n");
+    delete MPIManager;
+    MPIManager = nullptr;
+    return OFFLOAD_FAIL;
+  }
+
+  return OFFLOAD_SUCCESS;
+}
+
+int32_t __tgt_rtl_deinit_plugin() {
+  if (!MPIManager) {
+    DP("MPI plugin is not initialized. Deinitialization will do nothing.\n");
+    return OFFLOAD_FAIL;
+  }
+
+  int32_t Ret = OFFLOAD_SUCCESS;
+
+  if (!MPIManager->deinitialize()) {
+    REPORT("Failed to deinitialize the MPI plugin.\n");
+    Ret = OFFLOAD_FAIL;
+  }
+
+  delete MPIManager;
+  MPIManager = nullptr;
+
+  return Ret;
+}
+
+int32_t __tgt_rtl_init_device(int32_t DeviceId) { return OFFLOAD_SUCCESS; }
+
+int32_t __tgt_rtl_deinit_device(int32_t DeviceId) { return OFFLOAD_SUCCESS; }
+
+// Dynamic library loading and handling.
+// =============================================================================
+
+int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
+  return MPIManager->isValidBinary(Image);
+}
+
+__tgt_target_table *__tgt_rtl_load_binary(int32_t DeviceId,
+                                          __tgt_device_image *Image) {
+  return MPIManager->loadBinary(DeviceId, Image);
+}
+
+// Plugin and device information.
+// =============================================================================
+
+int32_t __tgt_rtl_number_of_devices() { return MPIManager->getNumOfDevices(); }
+
+int32_t __tgt_rtl_is_data_exchangable(int32_t SrcDevId, int32_t DstDevId) {
+  return MPIManager->isValidDeviceId(SrcDevId) &&
+         MPIManager->isValidDeviceId(DstDevId);
+}
+
+// Plugin and device configuration.
+// =============================================================================
+
+// TODO: Use this to control the plugin log level.
+void __tgt_rtl_set_info_flag(uint32_t) {}
+
+// Data management.
+// =========================================================================== + +void *__tgt_rtl_data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr, + int32_t Kind) { + return MPIManager->dataAlloc(DeviceId, Size, HostPtr, (TargetAllocTy)Kind); +} + +int32_t __tgt_rtl_data_delete(int32_t DeviceId, void *TargetPtr, int32_t Kind) { + return MPIManager->dataDelete(DeviceId, TargetPtr, (TargetAllocTy)Kind); +} + +int32_t __tgt_rtl_data_submit(int32_t DeviceId, void *TargetPtr, void *HostPtr, + int64_t Size) { + return syncCall(&MPIManagerTy::dataSubmit, DeviceId, TargetPtr, HostPtr, + Size); +} + +int32_t __tgt_rtl_data_submit_async(int32_t DeviceId, void *TargetPtr, + void *HostPtr, int64_t Size, + __tgt_async_info *AsyncInfo) { + return MPIManager->dataSubmit(DeviceId, TargetPtr, HostPtr, Size, AsyncInfo); +} + +int32_t __tgt_rtl_data_retrieve(int32_t DeviceId, void *HostPtr, + void *TargetPtr, int64_t Size) { + return syncCall(&MPIManagerTy::dataRetrieve, DeviceId, HostPtr, TargetPtr, + Size); +} + +int32_t __tgt_rtl_data_retrieve_async(int32_t DeviceId, void *HostPtr, + void *TargetPtr, int64_t Size, + __tgt_async_info *AsyncInfo) { + return MPIManager->dataRetrieve(DeviceId, HostPtr, TargetPtr, Size, + AsyncInfo); +} + +int32_t __tgt_rtl_data_exchange(int32_t SrcId, void *SrcPtr, int32_t DstId, + void *DstPtr, int64_t Size) { + return syncCall(&MPIManagerTy::dataExchange, SrcId, SrcPtr, DstId, DstPtr, + Size); +} + +int32_t __tgt_rtl_data_exchange_async(int32_t SrcId, void *SrcPtr, + int32_t DstId, void *DstPtr, int64_t Size, + __tgt_async_info *AsyncInfo) { + return MPIManager->dataExchange(SrcId, SrcPtr, DstId, DstPtr, Size, + AsyncInfo); +} + +// Target execution. +// =========================================================================== + +int32_t __tgt_rtl_run_target_region(int32_t DeviceId, void *Entry, void **Args, + ptrdiff_t *Offsets, int32_t NumArgs) { + return syncCall(&MPIManagerTy::runTargetRegion, DeviceId, Entry, Args, + Offsets, NumArgs); +} + +int32_t __tgt_rtl_run_target_region_async(int32_t DeviceId, void *Entry, + void **Args, ptrdiff_t *Offsets, + int32_t NumArgs, + __tgt_async_info *AsyncInfo) { + return MPIManager->runTargetRegion(DeviceId, Entry, Args, Offsets, NumArgs, + AsyncInfo); +} + +int32_t __tgt_rtl_run_target_team_region(int32_t DeviceId, void *Entry, + void **Args, ptrdiff_t *Offsets, + int32_t NumArgs, int32_t NumTeams, + int32_t ThreadLimit, + uint64_t LoopTripCount) { + return __tgt_rtl_run_target_region(DeviceId, Entry, Args, Offsets, NumArgs); +} + +int32_t __tgt_rtl_run_target_team_region_async( + int32_t DeviceId, void *Entry, void **Args, ptrdiff_t *Offsets, + int32_t NumArgs, int32_t NumTeams, int32_t ThreadLimit, + uint64_t LoopTripCount, __tgt_async_info *AsyncInfo) { + return __tgt_rtl_run_target_region_async(DeviceId, Entry, Args, Offsets, + NumArgs, AsyncInfo); +} + +// Asynchronous context management. +// =========================================================================== + +int32_t __tgt_rtl_synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo) { + return MPIManager->synchronize(DeviceId, AsyncInfo); +} + +// External events management. 
+// ===========================================================================
+
+int32_t __tgt_rtl_create_event(int32_t DeviceId, void **Event) {
+  return MPIManager->createEvent(DeviceId, Event);
+}
+
+int32_t __tgt_rtl_record_event(int32_t DeviceId, void *Event,
+                               __tgt_async_info *AsyncInfo) {
+  return MPIManager->recordEvent(DeviceId, Event, AsyncInfo);
+}
+
+int32_t __tgt_rtl_wait_event(int32_t DeviceId, void *Event,
+                             __tgt_async_info *AsyncInfo) {
+  return MPIManager->waitEvent(DeviceId, Event, AsyncInfo);
+}
+
+int32_t __tgt_rtl_sync_event(int32_t DeviceId, void *Event) {
+  return MPIManager->syncEvent(DeviceId, Event);
+}
+
+int32_t __tgt_rtl_destroy_event(int32_t DeviceId, void *Event) {
+  return MPIManager->destroyEvent(DeviceId, Event);
+}
+
+// Device side operations.
+// ===========================================================================
+
+int32_t __tgt_rtl_is_inside_device() { return MPIManager->isInsideDevice(); }
+
+void __tgt_rtl_run_device_main(__tgt_bin_desc *Desc) {
+  MPIManager->runDeviceMain(Desc);
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/openmp/libomptarget/src/interface.cpp b/openmp/libomptarget/src/interface.cpp
--- a/openmp/libomptarget/src/interface.cpp
+++ b/openmp/libomptarget/src/interface.cpp
@@ -41,6 +41,10 @@
     }
   }
   PM->RTLs.registerLib(Desc);
+
+  // Run device main instead of host main.
+  for (auto &R : PM->RTLs.ExecutableRTLs)
+    R->run_device_main(Desc);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp
--- a/openmp/libomptarget/src/rtl.cpp
+++ b/openmp/libomptarget/src/rtl.cpp
@@ -27,6 +27,7 @@
 
 // List of all plugins that can support offloading.
 static const char *RTLNames[] = {
+    /* MPI target */ "libomptarget.rtl.mpi.so",
     /* PowerPC target */ "libomptarget.rtl.ppc64.so",
     /* x86_64 target */ "libomptarget.rtl.x86_64.so",
     /* CUDA target */ "libomptarget.rtl.cuda.so",
@@ -107,18 +108,6 @@
     // Retrieve the RTL information from the runtime library.
     RTLInfoTy &R = AllRTLs.back();
 
-    // Remove plugin on failure to call optional init_plugin
-    *((void **)&R.init_plugin) =
-        DynLibrary->getAddressOfSymbol("__tgt_rtl_init_plugin");
-    if (R.init_plugin) {
-      int32_t Rc = R.init_plugin();
-      if (Rc != OFFLOAD_SUCCESS) {
-        DP("Unable to initialize library '%s': %u!\n", Name, Rc);
-        AllRTLs.pop_back();
-        continue;
-      }
-    }
-
     bool ValidPlugin = true;
 
     if (!(*((void **)&R.is_valid_binary) =
@@ -159,10 +148,45 @@
       continue;
     }
 
+    // Remove the plugin on failure to call the optional init_plugin. Also
+    // preload deinit_plugin so it can be called if the plugin ends up
+    // supporting no devices.
+    *((void **)&R.init_plugin) =
+        DynLibrary->getAddressOfSymbol("__tgt_rtl_init_plugin");
+    *((void **)&R.deinit_plugin) =
+        DynLibrary->getAddressOfSymbol("__tgt_rtl_deinit_plugin");
+    if (R.init_plugin) {
+      int32_t Rc = R.init_plugin();
+      if (Rc != OFFLOAD_SUCCESS) {
+        DP("Unable to initialize library '%s': %u!\n", Name, Rc);
+        AllRTLs.pop_back();
+        continue;
+      }
+    }
+
+    // Check if we are executing code inside the device itself and if we should
+    // run its main function.
+    *((void **)&R.is_inside_device) =
+        DynLibrary->getAddressOfSymbol("__tgt_rtl_is_inside_device");
+    *((void **)&R.run_device_main) =
+        DynLibrary->getAddressOfSymbol("__tgt_rtl_run_device_main");
+    const bool IsExecutable =
+        R.is_inside_device && R.is_inside_device() && R.run_device_main;
+    if (IsExecutable) {
+      ExecutableRTLs.emplace_back(&R);
+    }
+
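+    // Keep executable RTLs alive even when they expose no devices in this
+    // process: their worker ranks still need to reach run_device_main().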
     // No devices are supported by this RTL?
-    if (!(R.NumberOfDevices = R.number_of_devices())) {
+    if (!(R.NumberOfDevices = R.number_of_devices()) && !IsExecutable) {
       // The RTL is invalid! Will pop the object from the RTLs list.
       DP("No devices supported in this RTL\n");
+
+      // Deinit plugin before removing it.
+      if (R.deinit_plugin) {
+        int32_t Rc = R.deinit_plugin();
+        if (Rc != OFFLOAD_SUCCESS)
+          DP("Unable to deinitialize library '%s': %u!\n", Name, Rc);
+      }
+
       AllRTLs.pop_back();
       continue;
     }
@@ -175,8 +199,6 @@
            R.NumberOfDevices);
 
     // Optional functions
-    *((void **)&R.deinit_plugin) =
-        DynLibrary->getAddressOfSymbol("__tgt_rtl_deinit_plugin");
     *((void **)&R.is_valid_binary_info) =
         DynLibrary->getAddressOfSymbol("__tgt_rtl_is_valid_binary_info");
     *((void **)&R.deinit_device) =
@@ -562,10 +584,10 @@
   PM->TblMapMtx.unlock();
 
   // TODO: Write some RTL->unload_image(...) function?
-  for (auto *R : UsedRTLs) {
-    if (R->deinit_plugin) {
-      if (R->deinit_plugin() != OFFLOAD_SUCCESS) {
-        DP("Failure deinitializing RTL %s!\n", R->RTLName.c_str());
+  for (auto &R : AllRTLs) {
+    if (R.deinit_plugin) {
+      if (R.deinit_plugin() != OFFLOAD_SUCCESS) {
+        DP("Failure deinitializing RTL %s!\n", R.RTLName.c_str());
       }
     }
   }
diff --git a/openmp/libomptarget/test/lit.cfg b/openmp/libomptarget/test/lit.cfg
--- a/openmp/libomptarget/test/lit.cfg
+++ b/openmp/libomptarget/test/lit.cfg
@@ -111,6 +111,8 @@
 def remove_suffix_if_present(name):
     if name.endswith('-LTO'):
         return name[:-4]
+    elif name.endswith('-mpi'):
+        return name[:-4]
     else:
         return name