diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -708,6 +708,17 @@
   return RequiresFlags;
 }
 
+namespace {
+/// Clamp \p *value so that it does not exceed \p upper.
+/// \returns true if \p *value was lowered, false if it was left unchanged.
+template <typename T> bool enforce_upper_bound(T *value, T upper) {
+  bool changed = *value > upper;
+  if (changed)
+    *value = upper;
+  return changed;
+}
+} // namespace
+
 int32_t __tgt_rtl_init_device(int device_id) {
   hsa_status_t err;
 
@@ -769,11 +780,13 @@
     DeviceInfo.ThreadsPerGroup[device_id] =
         reinterpret_cast<uint32_t *>(&grid_max_dim)[0] /
         DeviceInfo.GroupsPerDevice[device_id];
-    if ((DeviceInfo.ThreadsPerGroup[device_id] >
-         RTLDeviceInfoTy::Max_WG_Size) ||
-        DeviceInfo.ThreadsPerGroup[device_id] == 0) {
-      DP("Capped thread limit: %d\n", RTLDeviceInfoTy::Max_WG_Size);
+
+    if (DeviceInfo.ThreadsPerGroup[device_id] == 0) {
       DeviceInfo.ThreadsPerGroup[device_id] = RTLDeviceInfoTy::Max_WG_Size;
+      DP("Default thread limit: %d\n", RTLDeviceInfoTy::Max_WG_Size);
+    } else if (enforce_upper_bound(&DeviceInfo.ThreadsPerGroup[device_id],
+                                   RTLDeviceInfoTy::Max_WG_Size)) {
+      DP("Capped thread limit: %d\n", RTLDeviceInfoTy::Max_WG_Size);
     } else {
       DP("Using ROCm Queried thread limit: %d\n",
          DeviceInfo.ThreadsPerGroup[device_id]);
@@ -799,9 +812,9 @@
   }
 
   // Adjust teams to the env variables
   if (DeviceInfo.EnvTeamLimit > 0 &&
-      DeviceInfo.GroupsPerDevice[device_id] > DeviceInfo.EnvTeamLimit) {
-    DeviceInfo.GroupsPerDevice[device_id] = DeviceInfo.EnvTeamLimit;
+      enforce_upper_bound(&DeviceInfo.GroupsPerDevice[device_id],
+                          DeviceInfo.EnvTeamLimit)) {
     DP("Capping max groups per device to OMP_TEAM_LIMIT=%d\n",
        DeviceInfo.EnvTeamLimit);
   }
@@ -824,8 +837,8 @@
        TeamsPerCU, DeviceInfo.ComputeUnits[device_id]);
   }
 
-  if (DeviceInfo.NumTeams[device_id] > DeviceInfo.GroupsPerDevice[device_id]) {
-    DeviceInfo.NumTeams[device_id] = DeviceInfo.GroupsPerDevice[device_id];
+  if (enforce_upper_bound(&DeviceInfo.NumTeams[device_id],
+                          DeviceInfo.GroupsPerDevice[device_id])) {
     DP("Default number of teams exceeds device limit, capping at %d\n",
        DeviceInfo.GroupsPerDevice[device_id]);
   }
@@ -834,9 +847,8 @@
   DeviceInfo.NumThreads[device_id] = RTLDeviceInfoTy::Default_WG_Size;
   DP("Default number of threads set according to library's default %d\n",
      RTLDeviceInfoTy::Default_WG_Size);
-  if (DeviceInfo.NumThreads[device_id] >
-      DeviceInfo.ThreadsPerGroup[device_id]) {
-    DeviceInfo.NumThreads[device_id] = DeviceInfo.ThreadsPerGroup[device_id];
+  if (enforce_upper_bound(&DeviceInfo.NumThreads[device_id],
+                          DeviceInfo.ThreadsPerGroup[device_id])) {
     DP("Default number of threads exceeds device limit, capping at %d\n",
        DeviceInfo.ThreadsPerGroup[device_id]);
   }