diff --git a/openmp/libomptarget/include/omptargetplugin.h b/openmp/libomptarget/include/omptargetplugin.h --- a/openmp/libomptarget/include/omptargetplugin.h +++ b/openmp/libomptarget/include/omptargetplugin.h @@ -133,6 +133,11 @@ // error code. int32_t __tgt_rtl_synchronize(int32_t ID, __tgt_async_info *AsyncInfoPtr); +// Register a callback function which will be invoked when the plugin is +// destructed. It accepts an argument of type void * which might be used in the +// callback function. +void __tgt_rtl_register_destruction_cb(void (*cb)(void *), void *data); + #ifdef __cplusplus } #endif diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp --- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp @@ -290,6 +290,15 @@ std::vector DeviceData; std::vector Modules; + // Struct to store callback information that will be used when the plugin is + // destructed. + struct DestructionCBInfoTy { + void (*CB)(void *); + void *Data; + }; + + DestructionCBInfoTy DestructionCBInfo; + // Record entry point associated with device void addOffloadEntry(const int DeviceId, const __tgt_offload_entry entry) { FuncOrGblEntryTy &E = DeviceData[DeviceId].FuncGblEntries.back(); @@ -382,6 +391,10 @@ } ~DeviceRTLTy() { + // Call the callback function if it exists + if (DestructionCBInfo.CB) + DestructionCBInfo.CB(DestructionCBInfo.Data); + // First destruct stream manager in case of Contexts is destructed before it StreamManager = nullptr; @@ -993,6 +1006,11 @@ return OFFLOAD_SUCCESS; } + + void registerDestructionCB(void (*CB)(void *), void *Data) { + DestructionCBInfo.CB = CB; + DestructionCBInfo.Data = Data; + } }; DeviceRTLTy DeviceRTL; @@ -1187,6 +1205,10 @@ return DeviceRTL.synchronize(device_id, async_info_ptr); } +void __tgt_rtl_register_destruction_cb(void (*cb)(void *), void *data) { + DeviceRTL.registerDestructionCB(cb, data); +} + #ifdef __cplusplus } #endif diff --git 
a/openmp/libomptarget/plugins/exports b/openmp/libomptarget/plugins/exports --- a/openmp/libomptarget/plugins/exports +++ b/openmp/libomptarget/plugins/exports @@ -19,6 +19,7 @@ __tgt_rtl_run_target_region; __tgt_rtl_run_target_region_async; __tgt_rtl_synchronize; + __tgt_rtl_register_destruction_cb; local: *; }; diff --git a/openmp/libomptarget/src/MemoryManager.cpp b/openmp/libomptarget/src/MemoryManager.cpp --- a/openmp/libomptarget/src/MemoryManager.cpp +++ b/openmp/libomptarget/src/MemoryManager.cpp @@ -92,9 +92,6 @@ } MemoryManagerTy::~MemoryManagerTy() { - // TODO: There is a little issue that target plugin is destroyed before this - // object, therefore the memory free will not succeed. - // Deallocate all memory in map for (auto Itr = PtrToNodeTable.begin(); Itr != PtrToNodeTable.end(); ++Itr) { assert(Itr->second.Ptr && "nullptr in map table"); deleteOnDevice(Itr->second.Ptr); diff --git a/openmp/libomptarget/src/rtl.h b/openmp/libomptarget/src/rtl.h --- a/openmp/libomptarget/src/rtl.h +++ b/openmp/libomptarget/src/rtl.h @@ -24,6 +24,12 @@ struct DeviceTy; struct __tgt_bin_desc; +extern "C" { +// Do some cleanup for the RTLInfoTy because the plugin dylib is being +// destructed, so any later use of the plugin should be avoided. 
+void cleanupRTL(void *); +} + struct RTLInfoTy { typedef int32_t(is_valid_binary_ty)(void *); typedef int32_t(is_data_exchangable_ty)(int32_t, int32_t); @@ -53,6 +59,7 @@ __tgt_async_info *); typedef int64_t(init_requires_ty)(int64_t); typedef int64_t(synchronize_ty)(int32_t, __tgt_async_info *); + typedef void(register_destruction_cb_ty)(void (*)(void *), void *); int32_t Idx = -1; // RTL index, index is the number of devices // of other RTLs that were registered before, @@ -86,6 +93,7 @@ run_team_region_async_ty *run_team_region_async = nullptr; init_requires_ty *init_requires = nullptr; synchronize_ty *synchronize = nullptr; + register_destruction_cb_ty *register_destruction_cb = nullptr; // Are there images associated with this RTL. bool isUsed = false; @@ -94,6 +102,9 @@ // It is easier to enforce thread-safety at the libomptarget level, // so that developers of new RTLs do not have to worry about it. std::mutex Mtx; + + // Vector of pointers of DeviceTy using this RTLInfoTy + std::vector Devices; }; /// RTLs identified in the system. diff --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp --- a/openmp/libomptarget/src/rtl.cpp +++ b/openmp/libomptarget/src/rtl.cpp @@ -10,9 +10,10 @@ // //===----------------------------------------------------------------------===// +#include "rtl.h" +#include "MemoryManager.h" #include "device.h" #include "private.h" -#include "rtl.h" #include #include @@ -168,6 +169,11 @@ dlsym(dynlib_handle, "__tgt_rtl_data_exchange_async"); *((void **)&R.is_data_exchangable) = dlsym(dynlib_handle, "__tgt_rtl_is_data_exchangable"); + *((void **)&R.register_destruction_cb) = + dlsym(dynlib_handle, "__tgt_rtl_register_destruction_cb"); + + if (R.register_destruction_cb) + R.register_destruction_cb(cleanupRTL, &R); } DP("RTLs loaded!\n"); @@ -303,15 +309,16 @@ // If this RTL is not already in use, initialize it. if (!R.isUsed) { // Initialize the device information for the RTL we are about to use. 
- DeviceTy device(&R); size_t Start = PM->Devices.size(); - PM->Devices.resize(Start + R.NumberOfDevices, device); - for (int32_t device_id = 0; device_id < R.NumberOfDevices; - device_id++) { + PM->Devices.resize(Start + R.NumberOfDevices, DeviceTy(&R)); + for (int32_t DeviceID = 0; DeviceID < R.NumberOfDevices; DeviceID++) { + DeviceTy &D = PM->Devices[Start + DeviceID]; // global device ID - PM->Devices[Start + device_id].DeviceID = Start + device_id; + D.DeviceID = Start + DeviceID; // RTL local device ID - PM->Devices[Start + device_id].RTLDeviceID = device_id; + D.RTLDeviceID = DeviceID; + // Save the DeviceTy pointer to RTLInfoTy + R.Devices.push_back(&D); } // Initialize the index of this RTL and save it in the used RTLs. @@ -448,3 +455,13 @@ DP("Done unregistering library!\n"); } + +extern "C" { +void cleanupRTL(void *Data) { + DP("Doing cleanup for RTLInfoTy " DPxMOD ".\n", DPxPTR(Data)); + + // Free all memory allocated to each device + for (DeviceTy *D : reinterpret_cast(Data)->Devices) + D->MemoryManager.release(); +} +}