diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -262,13 +262,11 @@ public: ResourcePoolTy(AllocatorTy &&A, size_t Size = 0) noexcept : Allocator(std::move(A)) { - (void)resize(Size); + if (Size) + (void)resize(Size); } - ~ResourcePoolTy() noexcept { - for (auto &R : Resources) - (void)Allocator.destroy(R); - } + ~ResourcePoolTy() noexcept { clear(); } /// Get a resource from pool. `Next` always points to the next available /// resource. That means, `[0, next-1]` have been assigned, and `[id,]` are @@ -283,8 +281,13 @@ /// Next int acquire(T &R) noexcept { std::lock_guard LG(Mutex); - if (Next == Resources.size() && !resize(Resources.size() * 2)) - return OFFLOAD_FAIL; + if (Next == Resources.size()) { + auto NewSize = Resources.size() ? Resources.size() * 2 : 1; + if (!resize(NewSize)) + return OFFLOAD_FAIL; + } + + assert(Next < Resources.size()); R = Resources[Next++]; @@ -307,6 +310,14 @@ std::lock_guard LG(Mutex); Resources[--Next] = R; } + + /// Release all stored resources and clear the pool. + /// Note: This function does not lock `Mutex`, so it is not thread safe. Be sure to guard it if necessary. 
+ void clear() noexcept { + for (auto &R : Resources) + (void)Allocator.destroy(R); + Resources.clear(); + } }; class DeviceRTLTy { @@ -328,7 +339,6 @@ static constexpr const int DefaultNumThreads = 128; using StreamPoolTy = ResourcePoolTy; - using StreamAllocatorTy = AllocatorTy; std::vector> StreamPool; std::vector DeviceData; @@ -563,7 +573,7 @@ checkResult(cuModuleUnload(M), "Error returned from cuModuleUnload\n"); for (auto &S : StreamPool) - S = nullptr; + S.reset(); for (DeviceDataTy &D : DeviceData) { // Destroy context @@ -631,7 +641,8 @@ // Initialize stream pool if (!StreamPool[DeviceId]) StreamPool[DeviceId] = std::make_unique( - StreamAllocatorTy(DeviceData[DeviceId].Context), NumInitialStreams); + AllocatorTy(DeviceData[DeviceId].Context), + NumInitialStreams); // Query attributes to determine number of threads/block and blocks/grid. int MaxGridDimX;