[libomptarget][amdgpu] Factor "cap a value at a limit" into a helper.

Adds enforce_upper_bound(T *value, T upper) in an anonymous namespace: it
writes `upper` to `*value` when `*value > upper` and returns whether it did.
__tgt_rtl_init_device uses it for the ThreadsPerGroup, GroupsPerDevice,
NumTeams and NumThreads caps. A queried thread limit of 0 is now handled in
its own branch and logged as "Default thread limit" rather than "Capped".

diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -731,6 +731,16 @@
   return RequiresFlags;
 }
 
+namespace {
+template <typename T> bool enforce_upper_bound(T *value, T upper) {
+  bool changed = *value > upper;
+  if (changed) {
+    *value = upper;
+  }
+  return changed;
+}
+} // namespace
+
 int32_t __tgt_rtl_init_device(int device_id) {
   hsa_status_t err;
 
@@ -792,11 +802,13 @@
     DeviceInfo.ThreadsPerGroup[device_id] =
         reinterpret_cast<uint32_t *>(&grid_max_dim)[0] /
         DeviceInfo.GroupsPerDevice[device_id];
-    if ((DeviceInfo.ThreadsPerGroup[device_id] >
-         RTLDeviceInfoTy::Max_WG_Size) ||
-        DeviceInfo.ThreadsPerGroup[device_id] == 0) {
-      DP("Capped thread limit: %d\n", RTLDeviceInfoTy::Max_WG_Size);
+
+    if (DeviceInfo.ThreadsPerGroup[device_id] == 0) {
       DeviceInfo.ThreadsPerGroup[device_id] = RTLDeviceInfoTy::Max_WG_Size;
+      DP("Default thread limit: %d\n", RTLDeviceInfoTy::Max_WG_Size);
+    } else if (enforce_upper_bound(&DeviceInfo.ThreadsPerGroup[device_id],
+                                   RTLDeviceInfoTy::Max_WG_Size)) {
+      DP("Capped thread limit: %d\n", RTLDeviceInfoTy::Max_WG_Size);
     } else {
       DP("Using ROCm Queried thread limit: %d\n",
          DeviceInfo.ThreadsPerGroup[device_id]);
@@ -822,9 +834,10 @@
   }
 
   // Adjust teams to the env variables
+
   if (DeviceInfo.EnvTeamLimit > 0 &&
-      DeviceInfo.GroupsPerDevice[device_id] > DeviceInfo.EnvTeamLimit) {
-    DeviceInfo.GroupsPerDevice[device_id] = DeviceInfo.EnvTeamLimit;
+      (enforce_upper_bound(&DeviceInfo.GroupsPerDevice[device_id],
+                           DeviceInfo.EnvTeamLimit))) {
     DP("Capping max groups per device to OMP_TEAM_LIMIT=%d\n",
        DeviceInfo.EnvTeamLimit);
   }
@@ -847,8 +860,8 @@
        TeamsPerCU, DeviceInfo.ComputeUnits[device_id]);
   }
 
-  if (DeviceInfo.NumTeams[device_id] > DeviceInfo.GroupsPerDevice[device_id]) {
-    DeviceInfo.NumTeams[device_id] = DeviceInfo.GroupsPerDevice[device_id];
+  if (enforce_upper_bound(&DeviceInfo.NumTeams[device_id],
+                          DeviceInfo.GroupsPerDevice[device_id])) {
     DP("Default number of teams exceeds device limit, capping at %d\n",
        DeviceInfo.GroupsPerDevice[device_id]);
   }
@@ -857,9 +870,8 @@
     DeviceInfo.NumThreads[device_id] = RTLDeviceInfoTy::Default_WG_Size;
     DP("Default number of threads set according to library's default %d\n",
        RTLDeviceInfoTy::Default_WG_Size);
-    if (DeviceInfo.NumThreads[device_id] >
-        DeviceInfo.ThreadsPerGroup[device_id]) {
-      DeviceInfo.NumThreads[device_id] = DeviceInfo.ThreadsPerGroup[device_id];
+    if (enforce_upper_bound(&DeviceInfo.NumThreads[device_id],
+                            DeviceInfo.ThreadsPerGroup[device_id])) {
       DP("Default number of threads exceeds device limit, capping at %d\n",
          DeviceInfo.ThreadsPerGroup[device_id]);
     }