// Patch summary: gate OMPT work in libomptarget and the nextgen plugins on a
// new `ompt::Initialized` flag.
//
// - include/OmptCallback.h declares `extern bool Initialized;`. Per its
//   comment, it is false if initializeLibrary has not been executed.
// - plugins-nextgen/common/OMPT/OmptCallback.cpp and src/OmptCallback.cpp
//   each define `Initialized = false` and set it to true immediately before
//   `return 0;` in an initialization routine. The routine names are outside
//   the hunks; presumably these are the initializeLibrary implementations
//   (TODO confirm). In the plugin copy this happens after
//   `lookupCallbackByName = lookup;`. In libomptarget it happens after
//   `LibraryFinalizer` is allocated.
// - PluginInterface.cpp now checks `ompt::Initialized` before each of these:
//     * binding device callbacks (bindOmptCallback). The macro body is now
//       braced, so the "OMPT: class bound" DP message is printed only when a
//       callback was actually bound, not unconditionally.
//     * the device_initialize callback.
//     * the device_finalize callback.
//     * the device_load callback.
// - ompt_libomptarget_connect registers the plugin finalizer and invokes its
//   init only when `Initialized` is true.
//
// Review notes (unverified; confirm before landing):
// - NOTE(review): `Initialized` is a plain `bool`. The per-device
//   `OmptInitialized` state next to it is changed with `store` and
//   `compare_exchange_strong`. If initialization can run concurrently with
//   device init/deinit on another thread, reads of `Initialized` are a data
//   race. Confirm the ordering guarantee, or make it atomic.
// - NOTE(review): no hunk resets `Initialized` to false on finalization.
//   Confirm this is intended.
// - NOTE(review): in ompt_libomptarget_connect, the new `Initialized` check
//   appears redundant with the existing `LibraryFinalizer` check. Both are
//   set in the same routine here. Verify that `LibraryFinalizer` is not
//   assigned anywhere else.
// - NOTE(review): `reinterpret_cast(this)` has no template argument in both
//   the removed and added lines. The `<...>` was presumably stripped during
//   extraction (likely `<ompt_device_t *>`; TODO confirm against the tree).
// - NOTE(review): this copy of the patch is collapsed onto three physical
//   lines. Newlines and blank context lines were lost, so the hunk headers
//   no longer match and `git apply` will reject it. Regenerate the patch
//   from the original commit.
diff --git a/openmp/libomptarget/include/OmptCallback.h b/openmp/libomptarget/include/OmptCallback.h --- a/openmp/libomptarget/include/OmptCallback.h +++ b/openmp/libomptarget/include/OmptCallback.h @@ -81,6 +81,9 @@ /// functions to their respective higher layer. void connectLibrary(); +/// OMPT initialization status; false if initializeLibrary has not been executed +extern bool Initialized; + } // namespace ompt } // namespace target } // namespace omp diff --git a/openmp/libomptarget/plugins-nextgen/common/OMPT/OmptCallback.cpp b/openmp/libomptarget/plugins-nextgen/common/OMPT/OmptCallback.cpp --- a/openmp/libomptarget/plugins-nextgen/common/OMPT/OmptCallback.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/OMPT/OmptCallback.cpp @@ -24,6 +24,8 @@ using namespace llvm::omp::target::ompt; +bool llvm::omp::target::ompt::Initialized = false; + ompt_get_callback_t llvm::omp::target::ompt::lookupCallbackByCode = nullptr; ompt_function_lookup_t llvm::omp::target::ompt::lookupCallbackByName = nullptr; @@ -43,6 +45,8 @@ // Store pointer of 'ompt_libomp_target_fn_lookup' for use by the plugin lookupCallbackByName = lookup; + Initialized = true; + return 0; } diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp @@ -406,10 +406,11 @@ OmptInitialized.store(false); // Bind the callbacks to this device's member functions #define bindOmptCallback(Name, Type, Code) \ - if (ompt::lookupCallbackByCode) \ + if (ompt::Initialized && ompt::lookupCallbackByCode) { \ ompt::lookupCallbackByCode((ompt_callbacks_t)(Code), \ ((ompt_callback_t *)&(Name##_fn))); \ - DP("OMPT: class bound %s=%p\n", #Name, ((void *)(uint64_t)Name##_fn)); + DP("OMPT: class bound %s=%p\n", #Name, ((void
*)(uint64_t)Name##_fn)); \ + } FOREACH_OMPT_DEVICE_EVENT(bindOmptCallback); #undef bindOmptCallback @@ -422,14 +423,16 @@ return Err; #ifdef OMPT_SUPPORT - bool ExpectedStatus = false; - if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true)) - performOmptCallback(device_initialize, - /* device_num */ DeviceId, - /* type */ getComputeUnitKind().c_str(), - /* device */ reinterpret_cast(this), - /* lookup */ ompt::lookupCallbackByName, - /* documentation */ nullptr); + if (ompt::Initialized) { + bool ExpectedStatus = false; + if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true)) + performOmptCallback(device_initialize, + /* device_num */ DeviceId, + /* type */ getComputeUnitKind().c_str(), + /* device */ reinterpret_cast(this), + /* lookup */ ompt::lookupCallbackByName, + /* documentation */ nullptr); + } #endif // Read and reinitialize the envars that depend on the device initialization. @@ -488,9 +491,11 @@ return Err; #ifdef OMPT_SUPPORT - bool ExpectedStatus = true; - if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false)) - performOmptCallback(device_finalize, /* device_num */ DeviceId); + if (ompt::Initialized) { + bool ExpectedStatus = true; + if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false)) + performOmptCallback(device_finalize, /* device_num */ DeviceId); + } #endif return deinitImpl(); @@ -536,16 +541,19 @@ return std::move(Err); #ifdef OMPT_SUPPORT - size_t Bytes = getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart); - performOmptCallback(device_load, - /* device_num */ DeviceId, - /* FileName */ nullptr, - /* File Offset */ 0, - /* VmaInFile */ nullptr, - /* ImgSize */ Bytes, - /* HostAddr */ InputTgtImage->ImageStart, - /* DeviceAddr */ nullptr, - /* FIXME: ModuleId */ 0); + if (ompt::Initialized) { + size_t Bytes = + getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart); + performOmptCallback(device_load, + /* device_num */ DeviceId, + /* FileName */ nullptr, + /* File Offset */
0, + /* VmaInFile */ nullptr, + /* ImgSize */ Bytes, + /* HostAddr */ InputTgtImage->ImageStart, + /* DeviceAddr */ nullptr, + /* FIXME: ModuleId */ 0); + } #endif // Return the pointer to the table of entries. diff --git a/openmp/libomptarget/src/OmptCallback.cpp b/openmp/libomptarget/src/OmptCallback.cpp --- a/openmp/libomptarget/src/OmptCallback.cpp +++ b/openmp/libomptarget/src/OmptCallback.cpp @@ -394,6 +394,8 @@ /// Object that will maintain the RTL finalizer from the plugin LibomptargetRtlFinalizer *LibraryFinalizer = nullptr; +bool llvm::omp::target::ompt::Initialized = false; + ompt_get_callback_t llvm::omp::target::ompt::lookupCallbackByCode = nullptr; ompt_function_lookup_t llvm::omp::target::ompt::lookupCallbackByName = nullptr; @@ -421,6 +423,8 @@ LibraryFinalizer = new LibomptargetRtlFinalizer(); + Initialized = true; + return 0; } @@ -463,7 +467,7 @@ /// Used for connecting libomptarget with a plugin void ompt_libomptarget_connect(ompt_start_tool_result_t *result) { DP("Enter ompt_libomptarget_connect\n"); - if (result && LibraryFinalizer) { + if (Initialized && result && LibraryFinalizer) { // Cache each fini function, so that they can be invoked on exit LibraryFinalizer->registerRtl(result->finalize); // Invoke the provided init function with the lookup function maintained