diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -1884,14 +1884,25 @@
 // Inputs: Max_Teams, Max_WG_Size, Warp_Size, ExecutionMode,
 //         EnvTeamLimit, EnvNumTeams, num_teams, thread_limit,
 //         loop_tripcount.
-void getLaunchVals(int &threadsPerGroup, int &num_groups, int ConstWGSize,
-                   int ExecutionMode, int EnvTeamLimit, int EnvNumTeams,
-                   int num_teams, int thread_limit, uint64_t loop_tripcount,
-                   int32_t device_id) {
+struct launchVals {
+  int threadsPerGroup;
+  int num_groups;
+};
+
+// Returns the chosen threadsPerGroup and num_groups. DeviceNumTeams (e.g.
+// DeviceInfo.NumTeams[device_id]) is the team limit used, before clamping to
+// DeviceInfo.HardTeamLimit, when DeviceInfo.EnvMaxTeamsDefault is not
+// positive.
+launchVals getLaunchVals(int ConstWGSize, int ExecutionMode, int EnvTeamLimit,
+                         int EnvNumTeams, int num_teams, int thread_limit,
+                         uint64_t loop_tripcount, int DeviceNumTeams) {
+
+  int threadsPerGroup = RTLDeviceInfoTy::Default_WG_Size;
+  int num_groups = 0;
 
   int Max_Teams = DeviceInfo.EnvMaxTeamsDefault > 0
                       ? DeviceInfo.EnvMaxTeamsDefault
-                      : DeviceInfo.NumTeams[device_id];
+                      : DeviceNumTeams;
   if (Max_Teams > DeviceInfo.HardTeamLimit)
     Max_Teams = DeviceInfo.HardTeamLimit;
 
@@ -2021,6 +2032,11 @@
   }
   DP("Final %d num_groups and %d threadsPerGroup\n", num_groups,
      threadsPerGroup);
+
+  launchVals res;
+  res.threadsPerGroup = threadsPerGroup;
+  res.num_groups = num_groups;
+  return res;
 }
 
 static uint64_t acquire_available_packet_id(hsa_queue_t *queue) {
@@ -2098,17 +2114,15 @@
   /*
    * Set limit based on ThreadsPerGroup and GroupsPerDevice
    */
-  int num_groups = 0;
-
-  int threadsPerGroup = RTLDeviceInfoTy::Default_WG_Size;
-
-  getLaunchVals(threadsPerGroup, num_groups, KernelInfo->ConstWGSize,
-                KernelInfo->ExecutionMode, DeviceInfo.EnvTeamLimit,
-                DeviceInfo.EnvNumTeams,
-                num_teams,      // From run_region arg
-                thread_limit,   // From run_region arg
-                loop_tripcount, // From run_region arg
-                KernelInfo->device_id);
+  launchVals LV =
+      getLaunchVals(KernelInfo->ConstWGSize, KernelInfo->ExecutionMode,
+                    DeviceInfo.EnvTeamLimit, DeviceInfo.EnvNumTeams,
+                    num_teams,      // From run_region arg
+                    thread_limit,   // From run_region arg
+                    loop_tripcount, // From run_region arg
+                    DeviceInfo.NumTeams[KernelInfo->device_id]);
+  int num_groups = LV.num_groups;
+  int threadsPerGroup = LV.threadsPerGroup;
 
   if (print_kernel_trace >= LAUNCH) {
     // enum modes are SPMD, GENERIC, NONE 0,1,2