diff --git a/openmp/libomptarget/include/omptargetplugin.h b/openmp/libomptarget/include/omptargetplugin.h --- a/openmp/libomptarget/include/omptargetplugin.h +++ b/openmp/libomptarget/include/omptargetplugin.h @@ -133,6 +133,11 @@ // error code. int32_t __tgt_rtl_synchronize(int32_t ID, __tgt_async_info *AsyncInfoPtr); +/// Return non-zero if the device's driver library already provides a memory +/// pool, in which case libomptarget does not use its host-side memory manager +/// for this device. Otherwise, return 0. +int32_t __tgt_rtl_is_memory_pool_supported(void); + #ifdef __cplusplus } #endif diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -1167,6 +1167,8 @@ return DeviceRTL.synchronize(device_id, async_info_ptr); } +int32_t __tgt_rtl_is_memory_pool_supported() { return 0; } + #ifdef __cplusplus } #endif diff --git a/openmp/libomptarget/plugins/exports b/openmp/libomptarget/plugins/exports --- a/openmp/libomptarget/plugins/exports +++ b/openmp/libomptarget/plugins/exports @@ -19,6 +19,7 @@ __tgt_rtl_run_target_region; __tgt_rtl_run_target_region_async; __tgt_rtl_synchronize; + __tgt_rtl_is_memory_pool_supported; local: *; }; diff --git a/openmp/libomptarget/src/CMakeLists.txt b/openmp/libomptarget/src/CMakeLists.txt --- a/openmp/libomptarget/src/CMakeLists.txt +++ b/openmp/libomptarget/src/CMakeLists.txt @@ -1,9 +1,9 @@ ##===----------------------------------------------------------------------===## -# +# # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# +# ##===----------------------------------------------------------------------===## # # Build offloading library libomptarget.so.
@@ -16,6 +16,7 @@ api.cpp device.cpp interface.cpp + memory.cpp rtl.cpp omptarget.cpp ) diff --git a/openmp/libomptarget/src/device.h b/openmp/libomptarget/src/device.h --- a/openmp/libomptarget/src/device.h +++ b/openmp/libomptarget/src/device.h @@ -17,6 +17,7 @@ #include #include #include +#include <memory> #include #include #include @@ -26,6 +27,9 @@ struct __tgt_bin_desc; struct __tgt_target_table; struct __tgt_async_info; +namespace memory { +class MemoryManagerTy; +} // namespace memory /// Map between host data and target data. struct HostDataToTargetTy { @@ -142,10 +146,14 @@ // moved into the target task in libomp. std::map LoopTripCnt; + /// Memory manager, or null. NOTE(review): confirm copy ctor/operator= copy it. + std::shared_ptr<memory::MemoryManagerTy> MemoryManager; + DeviceTy(RTLInfoTy *RTL) : DeviceID(-1), RTL(RTL), RTLDeviceID(-1), IsInit(false), InitFlag(), HasPendingGlobals(false), HostDataToTargetMap(), PendingCtorsDtors(), - ShadowPtrMap(), DataMapMtx(), PendingGlobalsMtx(), ShadowMtx() {} + ShadowPtrMap(), DataMapMtx(), PendingGlobalsMtx(), ShadowMtx(), + MemoryManager(nullptr) {} // The existence of mutexes makes DeviceTy non-copyable. We need to // provide a copy constructor and an assignment operator explicitly. diff --git a/openmp/libomptarget/src/device.cpp b/openmp/libomptarget/src/device.cpp --- a/openmp/libomptarget/src/device.cpp +++ b/openmp/libomptarget/src/device.cpp @@ -11,11 +11,13 @@ //===----------------------------------------------------------------------===// #include "device.h" +#include "memory.h" #include "private.h" #include "rtl.h" #include #include +#include <cstdlib> #include /// Map between Device ID (i.e. openmp device id) and its DeviceTy. @@ -321,10 +323,25 @@ // Make call to init_requires if it exists for this plugin.
if (RTL->init_requires) RTL->init_requires(RTLs->RequiresFlags); - int32_t rc = RTL->init_device(RTLDeviceID); - if (rc == OFFLOAD_SUCCESS) { - IsInit = true; - } + int32_t Ret = RTL->init_device(RTLDeviceID); + if (Ret != OFFLOAD_SUCCESS) + return; + + size_t Threshold = 1U << 13; + + if (const char *Env = std::getenv("LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD")) + Threshold = std::stoul(Env); + + // Only when the following conditions are met can we use memory manager: + // 1. Threshold is not set to 0 by the user via the environment variable + // 2. Device does not implement memory pool on their side. If the plugin does + // not implement the interface function \p is_memory_pool_supported, we assume + // the device does not support it. + if (Threshold && + (!RTL->is_memory_pool_supported || RTL->is_memory_pool_supported() == 0)) + MemoryManager = std::make_shared<memory::MemoryManagerTy>(*this, Threshold); + + IsInit = true; } /// Thread-safe method to initialize the device only once. @@ -352,10 +369,18 @@ } void *DeviceTy::allocData(int64_t Size, void *HstPtr) { + // If memory manager is enabled, we will allocate data via memory manager. + if (MemoryManager) + return MemoryManager->allocate(Size, HstPtr); + return RTL->data_alloc(RTLDeviceID, Size, HstPtr); } int32_t DeviceTy::deleteData(void *TgtPtrBegin) { + // If memory manager is enabled, we will deallocate data via memory manager. + if (MemoryManager) + return MemoryManager->free(TgtPtrBegin); + return RTL->data_delete(RTLDeviceID, TgtPtrBegin); } diff --git a/openmp/libomptarget/src/memory.h b/openmp/libomptarget/src/memory.h new file mode 100644 --- /dev/null +++ b/openmp/libomptarget/src/memory.h @@ -0,0 +1,40 @@ +//===----------- memory.h - Target independent memory manager -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations for target independent memory manager.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cstddef>
+#include <memory>
+
+// Forward declaration
+struct DeviceTy;
+
+namespace memory {
+namespace impl {
+class MemoryManagerTy;
+} // namespace impl
+
+class MemoryManagerTy {
+  std::shared_ptr<impl::MemoryManagerTy> Impl;
+
+public:
+  /// Pool blocks of up to \p Threshold bytes for \p D (0: use the default).
+  MemoryManagerTy(DeviceTy &D, size_t Threshold = 0);
+
+  /// Allocate \p Size bytes on the target device; \p HstPtr is forwarded to
+  /// the plugin. Returns nullptr on failure or when \p Size is 0.
+  void *allocate(size_t Size, void *HstPtr);
+
+  /// Release \p TgtPtr; pooled blocks are cached instead of freed.
+  int free(void *TgtPtr);
+};
+} // namespace memory
diff --git a/openmp/libomptarget/src/memory.cpp b/openmp/libomptarget/src/memory.cpp
new file mode 100644
--- /dev/null
+++ b/openmp/libomptarget/src/memory.cpp
@@ -0,0 +1,196 @@
+//===----------- memory.cpp - Target independent memory manager -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Functionality for managing target memory.
+//
+//===----------------------------------------------------------------------===//
+
+#include "memory.h"
+#include "device.h"
+#include "private.h"
+#include "rtl.h"
+
+#include <algorithm>
+#include <iterator>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+
+namespace memory {
+namespace impl {
+/// Lower bounds of the size classes. A block of \p Size bytes belongs to the
+/// last bucket whose lower bound does not exceed \p Size.
+constexpr const size_t BucketSize[] = {
+    0,       1U << 2, 1U << 3,  1U << 4,  1U << 5,  1U << 6, 1U << 7,
+    1U << 8, 1U << 9, 1U << 10, 1U << 11, 1U << 12, 1U << 13};
+
+constexpr const int NumBuckets = sizeof(BucketSize) / sizeof(size_t);
+
+/// Requests larger than this many bytes are not pooled, unless the manager is
+/// constructed with an explicit, non-zero threshold.
+constexpr const size_t DefaultSizeThreshold = BucketSize[NumBuckets - 1];
+
+/// Return the bucket for blocks of \p Size bytes, i.e. the largest index I
+/// such that BucketSize[I] <= Size.
+inline int findBucket(size_t Size) {
+  // BucketSize is sorted and starts with 0, so upper_bound never returns the
+  // first element and the result always lies in [0, NumBuckets).
+  const size_t *UB =
+      std::upper_bound(std::begin(BucketSize), std::end(BucketSize), Size);
+  return static_cast<int>(UB - std::begin(BucketSize)) - 1;
+}
+
+/// Pools the memory of one device so that a freed block can be handed out
+/// again to a later request of exactly the same size.
+class MemoryManagerTy {
+  /// Free blocks of one bucket, keyed by their exact size.
+  using FreeListTy = std::multimap<size_t, void *>;
+
+  /// One free list per bucket.
+  std::vector<FreeListTy> FreeLists;
+  /// One mutex per entry of \p FreeLists.
+  std::vector<std::mutex> FreeListLocks;
+  /// Size of every pooled block that is currently handed out.
+  std::unordered_map<void *, size_t> PtrToSizeTable;
+  /// The mutex for the table \p PtrToSizeTable.
+  std::mutex MapTableLock;
+  /// A reference to its corresponding \p DeviceTy object.
+  DeviceTy &Device;
+  /// Requests larger than this go straight to the device and are not pooled.
+  const size_t SizeThreshold;
+
+  /// Request memory from the plugin directly. Going through
+  /// DeviceTy::allocData would re-enter this manager and recurse forever.
+  void *allocateFromDevice(size_t Size, void *HstPtr) const {
+    return Device.RTL->data_alloc(Device.RTLDeviceID, Size, HstPtr);
+  }
+
+  /// Return memory to the plugin directly, for the same reason.
+  int deleteFromDevice(void *Ptr) const {
+    return Device.RTL->data_delete(Device.RTLDeviceID, Ptr);
+  }
+
+  /// Give every cached block in the free lists back to the device.
+  void releaseFreeLists() {
+    for (int I = 0; I < NumBuckets; ++I) {
+      std::lock_guard<std::mutex> Lock(FreeListLocks[I]);
+      for (const auto &Entry : FreeLists[I])
+        deleteFromDevice(Entry.second);
+      FreeLists[I].clear();
+    }
+  }
+
+  /// Allocate from the device. If that fails (possibly because cached blocks
+  /// hold the memory), release the free lists and try exactly once more.
+  void *allocateWithRetry(size_t Size, void *HstPtr) {
+    if (void *TgtPtr = allocateFromDevice(Size, HstPtr))
+      return TgtPtr;
+    releaseFreeLists();
+    return allocateFromDevice(Size, HstPtr);
+  }
+
+public:
+  /// Constructor. A \p Threshold of 0 selects \p DefaultSizeThreshold.
+  MemoryManagerTy(DeviceTy &Dev, size_t Threshold)
+      : FreeLists(NumBuckets), FreeListLocks(NumBuckets), Device(Dev),
+        SizeThreshold(Threshold ? Threshold : DefaultSizeThreshold) {}
+
+  /// Destructor
+  ~MemoryManagerTy() {
+    // TODO: There is a little issue that target plugin is destroyed before this
+    // object, therefore the memory free will not succeed.
+    // We don't need locks here because only one thread can execute it.
+    for (const FreeListTy &List : FreeLists)
+      for (const auto &Entry : List)
+        deleteFromDevice(Entry.second);
+
+    // Blocks that are still handed out are released as well.
+    for (const auto &Entry : PtrToSizeTable)
+      deleteFromDevice(Entry.first);
+  }
+
+  /// Allocate \p Size bytes, reusing a cached block of the same size if any.
+  void *allocate(size_t Size, void *HstPtr) {
+    // Zero-sized requests are not forwarded to the device.
+    if (Size == 0)
+      return nullptr;
+
+    // Large requests are neither pooled nor tracked.
+    if (Size > SizeThreshold)
+      return allocateWithRetry(Size, HstPtr);
+
+    void *TgtPtr = nullptr;
+
+    // Try to reuse a cached block of exactly the requested size.
+    {
+      const int B = findBucket(Size);
+      std::lock_guard<std::mutex> Lock(FreeListLocks[B]);
+      FreeListTy &List = FreeLists[B];
+      FreeListTy::iterator Itr = List.find(Size);
+      if (Itr != List.end()) {
+        TgtPtr = Itr->second;
+        List.erase(Itr);
+      }
+    }
+
+    // Nothing to reuse: get fresh memory from the device.
+    if (TgtPtr == nullptr) {
+      TgtPtr = allocateWithRetry(Size, HstPtr);
+      if (TgtPtr == nullptr)
+        return nullptr;
+    }
+
+    // Remember the size so that free() knows this block belongs to the pool.
+    {
+      std::lock_guard<std::mutex> Lock(MapTableLock);
+      PtrToSizeTable[TgtPtr] = Size;
+    }
+
+    return TgtPtr;
+  }
+
+  /// Take \p TgtPtr back. Pooled blocks are cached; others are released.
+  int free(void *TgtPtr) {
+    size_t Size = 0;
+    bool IsPooled = false;
+
+    {
+      std::lock_guard<std::mutex> Lock(MapTableLock);
+      auto Itr = PtrToSizeTable.find(TgtPtr);
+      if (Itr != PtrToSizeTable.end()) {
+        Size = Itr->second;
+        IsPooled = true;
+        PtrToSizeTable.erase(Itr);
+      }
+    }
+
+    // Blocks the pool does not track (e.g. large ones) go back to the device.
+    if (!IsPooled)
+      return deleteFromDevice(TgtPtr);
+
+    // Cache the block for reuse. It must not also be released to the device,
+    // otherwise a later allocate() would hand out freed memory.
+    const int B = findBucket(Size);
+    std::lock_guard<std::mutex> Lock(FreeListLocks[B]);
+    FreeLists[B].emplace(Size, TgtPtr);
+    return OFFLOAD_SUCCESS;
+  }
+};
+} // namespace impl
+
+void *MemoryManagerTy::allocate(size_t Size, void *HstPtr) {
+  return Impl->allocate(Size, HstPtr);
+}
+
+int MemoryManagerTy::free(void *TgtPtr) { return Impl->free(TgtPtr); }
+
+MemoryManagerTy::MemoryManagerTy(DeviceTy &D, size_t Threshold)
+    : Impl(std::make_shared<impl::MemoryManagerTy>(D, Threshold)) {}
+} // namespace memory
diff --git a/openmp/libomptarget/src/rtl.h b/openmp/libomptarget/src/rtl.h --- a/openmp/libomptarget/src/rtl.h +++ b/openmp/libomptarget/src/rtl.h @@ -53,6 +53,7 @@ __tgt_async_info *); typedef int64_t(init_requires_ty)(int64_t); typedef int64_t(synchronize_ty)(int32_t, __tgt_async_info *); + typedef int32_t(is_memory_pool_supported_ty)(); int32_t Idx = -1; // RTL index, index is the number of devices // of other RTLs that were registered before, @@ -86,6 +87,7 @@ run_team_region_async_ty *run_team_region_async = nullptr; init_requires_ty *init_requires = nullptr; synchronize_ty *synchronize = nullptr; + is_memory_pool_supported_ty *is_memory_pool_supported = nullptr; // Are there images associated with this RTL. bool isUsed = false; @@ -126,6 +128,7 @@ init_requires = r.init_requires; isUsed = r.isUsed; synchronize = r.synchronize; + is_memory_pool_supported = r.is_memory_pool_supported; } };
diff --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp --- a/openmp/libomptarget/src/rtl.cpp +++ b/openmp/libomptarget/src/rtl.cpp @@ -10,9 +10,9 @@ // //===----------------------------------------------------------------------===// +#include "rtl.h" #include "device.h" #include "private.h" -#include "rtl.h" #include #include @@ -146,6 +146,8 @@ dlsym(dynlib_handle, "__tgt_rtl_data_exchange_async"); *((void **)&R.is_data_exchangable) = dlsym(dynlib_handle, "__tgt_rtl_is_data_exchangable"); + *((void **)&R.is_memory_pool_supported) = + dlsym(dynlib_handle, "__tgt_rtl_is_memory_pool_supported"); // No devices are supported by this RTL? if (!(R.NumberOfDevices = R.number_of_devices())) {