diff --git a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h --- a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h +++ b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h @@ -9,6 +9,10 @@ #include "atmi_runtime.h" #include "hsa.h" #include "hsa_ext_amd.h" +#include "internal.h" + +#include +#include #ifdef __cplusplus extern "C" { @@ -44,11 +48,10 @@ * * @retval ::ATMI_STATUS_UNKNOWN The function encountered errors. */ -atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, - const char *symbol, - void **var_addr, - unsigned int *var_size); - +atmi_status_t atmi_interop_hsa_get_symbol_info( + const std::map &SymbolInfoTable, + atmi_mem_place_t place, const char *symbol, void **var_addr, + unsigned int *var_size); /** * @brief Get the HSA-specific kernel info from a kernel name * @@ -75,8 +78,10 @@ * @retval ::ATMI_STATUS_UNKNOWN The function encountered errors. */ atmi_status_t atmi_interop_hsa_get_kernel_info( + const std::map &KernelInfoTable, atmi_mem_place_t place, const char *kernel_name, hsa_executable_symbol_info_t info, uint32_t *value); + /** @} */ #ifdef __cplusplus diff --git a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp --- a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp +++ b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp @@ -8,10 +8,10 @@ using core::atl_is_atmi_initialized; -atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, - const char *symbol, - void **var_addr, - unsigned int *var_size) { +atmi_status_t atmi_interop_hsa_get_symbol_info( + const std::map &SymbolInfoTable, + atmi_mem_place_t place, const char *symbol, void **var_addr, + unsigned int *var_size) { /* // Typical usage: void *var_addr; @@ -32,9 +32,9 @@ // get the symbol info std::string symbolStr = std::string(symbol); - if (SymbolInfoTable[place.dev_id].find(symbolStr) != - SymbolInfoTable[place.dev_id].end()) { - atl_symbol_info_t info = SymbolInfoTable[place.dev_id][symbolStr]; + auto It = SymbolInfoTable.find(symbolStr); + if (It != SymbolInfoTable.end()) { + atl_symbol_info_t info = It->second; *var_addr = reinterpret_cast(info.addr); *var_size = info.size; return ATMI_STATUS_SUCCESS; @@ -46,6 +46,7 @@ } atmi_status_t atmi_interop_hsa_get_kernel_info( + const std::map &KernelInfoTable, atmi_mem_place_t place, const char *kernel_name, hsa_executable_symbol_info_t kernel_info, uint32_t *value) { /* @@ -68,9 +69,9 @@ atmi_status_t status = ATMI_STATUS_SUCCESS; // get the kernel info std::string kernelStr = std::string(kernel_name); - if (KernelInfoTable[place.dev_id].find(kernelStr) != - KernelInfoTable[place.dev_id].end()) { - atl_kernel_info_t info = KernelInfoTable[place.dev_id][kernelStr]; + auto It = KernelInfoTable.find(kernelStr); + if (It != KernelInfoTable.end()) { + atl_kernel_info_t info = It->second; switch (kernel_info) { case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE: *value = info.group_segment_size; diff --git a/openmp/libomptarget/plugins/amdgpu/impl/internal.h b/openmp/libomptarget/plugins/amdgpu/impl/internal.h --- a/openmp/libomptarget/plugins/amdgpu/impl/internal.h +++ b/openmp/libomptarget/plugins/amdgpu/impl/internal.h @@ -106,9 +106,6 @@ uint32_t size; } atl_symbol_info_t; -extern std::vector> KernelInfoTable; -extern std::vector> SymbolInfoTable; - // ---------------------- Kernel End ------------- namespace core { diff --git a/openmp/libomptarget/plugins/amdgpu/impl/system.cpp b/openmp/libomptarget/plugins/amdgpu/impl/system.cpp --- a/openmp/libomptarget/plugins/amdgpu/impl/system.cpp +++ b/openmp/libomptarget/plugins/amdgpu/impl/system.cpp @@ -146,9 +146,6 @@ std::vector atl_gpu_kernarg_pools; -std::vector> KernelInfoTable; -std::vector> SymbolInfoTable; - bool g_atmi_initialized = false; bool g_atmi_hostcall_required = false; @@ -209,15 +206,6 @@ atmi_status_t Runtime::Finalize() { atmi_status_t rc = ATMI_STATUS_SUCCESS; - for (uint32_t i = 0; i < SymbolInfoTable.size(); i++) { - SymbolInfoTable[i].clear(); - } - SymbolInfoTable.clear(); - for (uint32_t i = 0; i < KernelInfoTable.size(); i++) { - KernelInfoTable[i].clear(); - } - KernelInfoTable.clear(); - atl_reset_atmi_initialized(); hsa_status_t err = hsa_shut_down(); if (err != HSA_STATUS_SUCCESS) { @@ -557,13 +545,6 @@ return err; } - int gpu_count = g_atl_machine.processorCount(); - KernelInfoTable.resize(gpu_count); - SymbolInfoTable.resize(gpu_count); - for (uint32_t i = 0; i < SymbolInfoTable.size(); i++) - SymbolInfoTable[i].clear(); - for (uint32_t i = 0; i < KernelInfoTable.size(); i++) - KernelInfoTable[i].clear(); atlc.g_hsa_initialized = true; DEBUG_PRINT("done\n"); } @@ -836,8 +817,9 @@ } } // namespace -static hsa_status_t get_code_object_custom_metadata(void *binary, - size_t binSize, int gpu) { +static hsa_status_t get_code_object_custom_metadata( + void *binary, size_t binSize, int gpu, + std::map &KernelInfoTable) { // parse code object with different keys from v2 // also, the kernel name is not the same as the symbol name -- so a // symbol->name map is needed @@ -1004,14 +986,16 @@ kernel_segment_size, info.kernel_segment_size); // kernel received, now add it to the kernel info table - KernelInfoTable[gpu][kernelName] = info; + KernelInfoTable[kernelName] = info; } return HSA_STATUS_SUCCESS; } -static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol, - int gpu) { +static hsa_status_t +populate_InfoTables(hsa_executable_symbol_t symbol, int gpu, + std::map &KernelInfoTable, + std::map &SymbolInfoTable) { hsa_symbol_kind_t type; uint32_t name_length; @@ -1048,11 +1032,16 @@ // by now, the kernel info table should already have an entry // because the non-ROCr custom code object parsing is called before // iterating over the code object symbols using ROCr - if (KernelInfoTable[gpu].find(kernelName) == KernelInfoTable[gpu].end()) { - return HSA_STATUS_ERROR; + if (KernelInfoTable.find(kernelName) == KernelInfoTable.end()) { + if (HSA_STATUS_ERROR_INVALID_CODE_OBJECT != HSA_STATUS_SUCCESS) { + printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__, + "Finding the entry kernel info table", + get_error_string(HSA_STATUS_ERROR_INVALID_CODE_OBJECT)); + exit(1); + } } // found, so assign and update - info = KernelInfoTable[gpu][kernelName]; + info = KernelInfoTable[kernelName]; /* Extract dispatch information from the symbol */ err = hsa_executable_symbol_get_info( @@ -1090,7 +1079,7 @@ info.private_segment_size, info.kernel_segment_size); // assign it back to the kernel info table - KernelInfoTable[gpu][kernelName] = info; + KernelInfoTable[kernelName] = info; free(name); } else if (type == HSA_SYMBOL_KIND_VARIABLE) { err = hsa_executable_symbol_get_info( @@ -1136,7 +1125,7 @@ if (err != HSA_STATUS_SUCCESS) { return err; } - SymbolInfoTable[gpu][std::string(name)] = info; + SymbolInfoTable[std::string(name)] = info; if (strcmp(name, "needs_hostcall_buffer") == 0) g_atmi_hostcall_required = true; free(name); @@ -1146,7 +1135,9 @@ return HSA_STATUS_SUCCESS; } -atmi_status_t Runtime::RegisterModuleFromMemory( +atmi_status_t RegisterModuleFromMemory( + std::map &KernelInfoTable, + std::map &SymbolInfoTable, void *module_bytes, size_t module_size, atmi_place_t place, atmi_status_t (*on_deserialized_data)(void *data, size_t size, void *cb_state), @@ -1186,7 +1177,8 @@ // Some metadata info is not available through ROCr API, so use custom // code object metadata parsing to collect such metadata info - err = get_code_object_custom_metadata(module_bytes, module_size, gpu); + err = get_code_object_custom_metadata(module_bytes, module_size, gpu, + KernelInfoTable); if (err != HSA_STATUS_SUCCESS) { DEBUG_PRINT("[%s:%d] %s failed: %s\n", __FILE__, __LINE__, "Getting custom code object metadata", @@ -1243,9 +1235,9 @@ err = hsa::executable_iterate_symbols( executable, [&](hsa_executable_t, hsa_executable_symbol_t symbol) -> hsa_status_t { - return populate_InfoTables(symbol, gpu); + return populate_InfoTables(symbol, gpu, KernelInfoTable, + SymbolInfoTable); }); - if (err != HSA_STATUS_SUCCESS) { printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__, "Iterating over symbols for execuatable", get_error_string(err)); diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp --- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp +++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp @@ -86,6 +86,16 @@ #include "elf_common.h" +namespace core { +atmi_status_t RegisterModuleFromMemory( + std::map &KernelInfo, + std::map &SymbolInfoTable, void *, size_t, + atmi_place_t, + atmi_status_t (*on_deserialized_data)(void *data, size_t size, + void *cb_state), + void *cb_state, std::vector &HSAExecutables); +} + /// Keep entries table per device struct FuncOrGblEntryTy { __tgt_target_table Table; @@ -337,6 +347,9 @@ std::vector HSAExecutables; + std::vector> KernelInfoTable; + std::vector> SymbolInfoTable; + struct atmiFreePtrDeletor { void operator()(void *p) { atmi_free(p); // ignore failure to free @@ -480,6 +493,8 @@ NumTeams.resize(NumberOfDevices); NumThreads.resize(NumberOfDevices); deviceStateStore.resize(NumberOfDevices); + KernelInfoTable.resize(NumberOfDevices); + SymbolInfoTable.resize(NumberOfDevices); for (int i = 0; i < NumberOfDevices; i++) { HSAQueues[i] = nullptr; @@ -991,15 +1006,17 @@ template atmi_status_t module_register_from_memory_to_place( + std::map &KernelInfoTable, + std::map &SymbolInfoTable, void *module_bytes, size_t module_size, atmi_place_t place, C cb, std::vector &HSAExecutables) { auto L = [](void *data, size_t size, void *cb_state) -> atmi_status_t { C *unwrapped = static_cast(cb_state); return (*unwrapped)(data, size); }; - return core::Runtime::RegisterModuleFromMemory( - module_bytes, module_size, place, L, static_cast(&cb), - HSAExecutables); + return core::RegisterModuleFromMemory( + KernelInfoTable, SymbolInfoTable, module_bytes, module_size, place, L, + static_cast(&cb), HSAExecutables); } } // namespace @@ -1114,11 +1131,12 @@ DP("Setting global device environment after load (%u bytes)\n", si.size); int device_id = host_device_env.device_num; - + auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id]; void *state_ptr; uint32_t state_ptr_size; atmi_status_t err = atmi_interop_hsa_get_symbol_info( - get_gpu_mem_place(device_id), sym(), &state_ptr, &state_ptr_size); + SymbolInfo, get_gpu_mem_place(device_id), sym(), &state_ptr, + &state_ptr_size); if (err != ATMI_STATUS_SUCCESS) { DP("failed to find %s in loaded image\n", sym()); return err; @@ -1197,8 +1215,11 @@ auto env = device_environment(device_id, DeviceInfo.NumberOfDevices, image, img_size); + auto &KernelInfo = DeviceInfo.KernelInfoTable[device_id]; + auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id]; atmi_status_t err = module_register_from_memory_to_place( - (void *)image->ImageStart, img_size, get_gpu_place(device_id), + KernelInfo, SymbolInfo, (void *)image->ImageStart, img_size, + get_gpu_place(device_id), [&](void *data, size_t size) { return env.before_loading(data, size); }, DeviceInfo.HSAExecutables); @@ -1227,9 +1248,10 @@ void *state_ptr; uint32_t state_ptr_size; + auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id]; atmi_status_t err = atmi_interop_hsa_get_symbol_info( - get_gpu_mem_place(device_id), "omptarget_nvptx_device_State", - &state_ptr, &state_ptr_size); + SymbolInfoMap, get_gpu_mem_place(device_id), + "omptarget_nvptx_device_State", &state_ptr, &state_ptr_size); if (err != ATMI_STATUS_SUCCESS) { DP("No device_state symbol found, skipping initialization\n"); @@ -1311,8 +1333,10 @@ void *varptr; uint32_t varsize; + auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id]; atmi_status_t err = atmi_interop_hsa_get_symbol_info( - get_gpu_mem_place(device_id), e->name, &varptr, &varsize); + SymbolInfoMap, get_gpu_mem_place(device_id), e->name, &varptr, + &varsize); if (err != ATMI_STATUS_SUCCESS) { // Inform the user what symbol prevented offloading @@ -1353,8 +1377,10 @@ atmi_mem_place_t place = get_gpu_mem_place(device_id); uint32_t kernarg_segment_size; + auto &KernelInfoMap = DeviceInfo.KernelInfoTable[device_id]; atmi_status_t err = atmi_interop_hsa_get_kernel_info( - place, e->name, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, + KernelInfoMap, place, e->name, + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, &kernarg_segment_size); // each arg is a void * in this openmp implementation @@ -1782,6 +1808,7 @@ KernelTy *KernelInfo = (KernelTy *)tgt_entry_ptr; std::string kernel_name = std::string(KernelInfo->Name); + auto &KernelInfoTable = DeviceInfo.KernelInfoTable; if (KernelInfoTable[device_id].find(kernel_name) == KernelInfoTable[device_id].end()) { DP("Kernel %s not found\n", kernel_name.c_str());