diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -304,49 +304,6 @@
   return header;
 }
 
-hsa_status_t addKernArgPool(hsa_amd_memory_pool_t MemoryPool, void *Data) {
-  std::vector<hsa_amd_memory_pool_t> *Result =
-      static_cast<std::vector<hsa_amd_memory_pool_t> *>(Data);
-  bool AllocAllowed = false;
-  hsa_status_t err = hsa_amd_memory_pool_get_info(
-      MemoryPool, HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED,
-      &AllocAllowed);
-  if (err != HSA_STATUS_SUCCESS) {
-    DP("Alloc allowed in memory pool check failed: %s\n",
-       get_error_string(err));
-    return err;
-  }
-
-  if (!AllocAllowed) {
-    // nothing needs to be done here.
-    return HSA_STATUS_SUCCESS;
-  }
-
-  uint32_t GlobalFlags = 0;
-  err = hsa_amd_memory_pool_get_info(
-      MemoryPool, HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, &GlobalFlags);
-  if (err != HSA_STATUS_SUCCESS) {
-    DP("Get memory pool info failed: %s\n", get_error_string(err));
-    return err;
-  }
-
-  size_t size = 0;
-  err = hsa_amd_memory_pool_get_info(MemoryPool, HSA_AMD_MEMORY_POOL_INFO_SIZE,
-                                     &size);
-  if (err != HSA_STATUS_SUCCESS) {
-    DP("Get memory pool size failed: %s\n", get_error_string(err));
-    return err;
-  }
-
-  if ((GlobalFlags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED) &&
-      (GlobalFlags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT) &&
-      size > 0) {
-    Result->push_back(MemoryPool);
-  }
-
-  return HSA_STATUS_SUCCESS;
-}
-
 std::pair<hsa_status_t, bool>
 isValidMemoryPool(hsa_amd_memory_pool_t MemoryPool) {
   bool AllocAllowed = false;
@@ -362,55 +319,6 @@
   return {HSA_STATUS_SUCCESS, AllocAllowed};
 }
 
-template <typename AccumulatorFunc>
-hsa_status_t collectMemoryPools(const std::vector<hsa_agent_t> &Agents,
-                                AccumulatorFunc Func) {
-  for (int DeviceId = 0; DeviceId < Agents.size(); DeviceId++) {
-    hsa_status_t Err = hsa::amd_agent_iterate_memory_pools(
-        Agents[DeviceId], [&](hsa_amd_memory_pool_t MemoryPool) {
-          hsa_status_t Err;
-          bool Valid = false;
-          std::tie(Err, Valid) =
-              isValidMemoryPool(MemoryPool);
-          if (Err != HSA_STATUS_SUCCESS) {
-            return Err;
-          }
-          if (Valid)
-            Func(MemoryPool, DeviceId);
-          return HSA_STATUS_SUCCESS;
-        });
-
-    if (Err != HSA_STATUS_SUCCESS) {
-      DP("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
-         "Iterate all memory pools", get_error_string(Err));
-      return Err;
-    }
-  }
-
-  return HSA_STATUS_SUCCESS;
-}
-
-std::pair<hsa_status_t, hsa_amd_memory_pool_t>
-FindKernargPool(const std::vector<hsa_agent_t> &HSAAgents) {
-  std::vector<hsa_amd_memory_pool_t> KernArgPools;
-  for (const auto &Agent : HSAAgents) {
-    hsa_status_t err = HSA_STATUS_SUCCESS;
-    err = hsa_amd_agent_iterate_memory_pools(
-        Agent, addKernArgPool, static_cast<void *>(&KernArgPools));
-    if (err != HSA_STATUS_SUCCESS) {
-      DP("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
-         "Iterate all memory pools", get_error_string(err));
-      return {err, hsa_amd_memory_pool_t{}};
-    }
-  }
-
-  if (KernArgPools.empty()) {
-    DP("Unable to find any valid kernarg pool\n");
-    return {HSA_STATUS_ERROR, hsa_amd_memory_pool_t{}};
-  }
-
-  return {HSA_STATUS_SUCCESS, KernArgPools[0]};
-}
-
 hsa_status_t addMemoryPool(hsa_amd_memory_pool_t MemoryPool, void *Data) {
   std::vector<hsa_amd_memory_pool_t> *Result =
       static_cast<std::vector<hsa_amd_memory_pool_t> *>(Data);
@@ -652,15 +560,31 @@
     return HSA_STATUS_SUCCESS;
   }
 
-  hsa_status_t setupMemoryPools() {
-    using namespace std::placeholders;
-    hsa_status_t Err;
-    Err = core::collectMemoryPools(
-        HSAAgents, std::bind(&RTLDeviceInfoTy::addDeviceMemoryPool, this, _1, _2));
-    if (Err != HSA_STATUS_SUCCESS) {
-      DP("HSA error in collecting memory pools for offload devices: %s\n",
-         get_error_string(Err));
-      return Err;
+  hsa_status_t setupDevicePools(const std::vector<hsa_agent_t> &Agents) {
+    for (int DeviceId = 0; DeviceId < Agents.size(); DeviceId++) {
+      hsa_status_t Err = hsa::amd_agent_iterate_memory_pools(
+          Agents[DeviceId], [&](hsa_amd_memory_pool_t MemoryPool) {
+            bool AllocAllowed = false;
+            hsa_status_t ErrGetInfo;
+            std::tie(ErrGetInfo, AllocAllowed) =
+                core::isValidMemoryPool(MemoryPool);
+            if (ErrGetInfo != HSA_STATUS_SUCCESS) {
+              DP("Alloc allowed in memory pool check failed: %s\n",
+                 get_error_string(ErrGetInfo));
+              return ErrGetInfo;
+            }
+            if (AllocAllowed) {
+              return addDeviceMemoryPool(MemoryPool, DeviceId);
+            }
+
+            return HSA_STATUS_SUCCESS;
+          });
+
+      if (Err != HSA_STATUS_SUCCESS) {
+        DP("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
+           "Iterate all memory pools", get_error_string(Err));
+        return Err;
+      }
     }
     return HSA_STATUS_SUCCESS;
   }
@@ -679,6 +603,11 @@
       }
     }
 
+    // We need two fine-grained pools.
+    // 1. One with kernarg flag set for storing kernel arguments
+    // 2. Second for host allocations
+    bool FineGrainedMemoryPoolSet = false;
+    bool KernArgPoolSet = false;
     for (const auto &MemoryPool : HostPools) {
       hsa_status_t Err = HSA_STATUS_SUCCESS;
       uint32_t GlobalFlags = 0;
@@ -699,10 +628,26 @@
 
       if ((GlobalFlags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED) &&
          Size > 0) {
-        HostFineGrainedMemoryPool = MemoryPool;
+        if (GlobalFlags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT) {
+          KernArgPool = MemoryPool;
+          KernArgPoolSet = true;
+        } else {
+          HostFineGrainedMemoryPool = MemoryPool;
+          FineGrainedMemoryPoolSet = true;
+        }
       }
     }
 
+    if (!KernArgPoolSet) {
+      DP("No fine-grained kernarg pool found\n");
+      return HSA_STATUS_ERROR;
+    }
+
+    if (!FineGrainedMemoryPoolSet) {
+      // only other option is to share the kernarg pool.
+      FineGrainedMemoryPool = KernArgPool;
+    }
+
     return HSA_STATUS_SUCCESS;
   }
 
@@ -772,11 +717,6 @@
     } else {
       DP("There are %d devices supporting HSA.\n", NumberOfDevices);
     }
-    std::tie(err, KernArgPool) = core::FindKernargPool(CPUAgents);
-    if (err != HSA_STATUS_SUCCESS) {
-      DP("Error when reading memory pools\n");
-      return;
-    }
 
     // Init the device info
     HSAQueues.resize(NumberOfDevices);
@@ -794,14 +734,15 @@
     DeviceCoarseGrainedMemoryPools.resize(NumberOfDevices);
     DeviceFineGrainedMemoryPools.resize(NumberOfDevices);
 
-    err = setupMemoryPools();
+    err = setupDevicePools(HSAAgents);
     if (err != HSA_STATUS_SUCCESS) {
-      DP("Error when setting up memory pools");
+      DP("Setup for Device Memory Pools failed\n");
       return;
     }
+
     err = setupHostMemoryPools(CPUAgents);
     if (err != HSA_STATUS_SUCCESS) {
-      DP("Error when setting up host memory pools");
+      DP("Setup for Host Memory Pools failed\n");
       return;
     }
 