Index: llvm/include/llvm/Frontend/OpenMP/OMPGridValues.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPGridValues.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPGridValues.h
@@ -86,7 +86,9 @@
   GV_Max_Warp_Number,
   /// The slot size that should be reserved for a working warp.
   /// (~0u >> (GV_Warp_Size - GV_Warp_Size_Log2))
-  GV_Warp_Size_Log2_MaskL
+  GV_Warp_Size_Log2_MaskL,
+  /// Total number of vector registers per CU or SM
+  GV_Total_Vector_Registers
 };
 
 /// For AMDGPU GPUs
@@ -104,7 +106,8 @@
     1024, // GV_Max_WG_Size,
     256, // GV_Defaut_WG_Size
     1024 / 64, // GV_Max_WG_Size / GV_WarpSize
-    63 // GV_Warp_Size_Log2_MaskL
+    63, // GV_Warp_Size_Log2_MaskL
+    64 * 1024 // GV_Total_Vector_Registers
 };
 
 /// For Nvidia GPUs
@@ -122,7 +125,8 @@
     1024, // GV_Max_WG_Size
     128, // GV_Defaut_WG_Size
     1024 / 32, // GV_Max_WG_Size / GV_WarpSize
-    31 // GV_Warp_Size_Log2_MaskL
+    31, // GV_Warp_Size_Log2_MaskL
+    32 * 1024 // GV_Total_Vector_Registers
 };
 
 } // namespace omp
Index: openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
===================================================================
--- openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -70,6 +70,11 @@
   }
 }
 
+// Heuristic parameters used for kernel launch
+// Number of teams per CU to allow scheduling flexibility
+static const unsigned DefaultTeamsPerCU = 4;
+static const unsigned MinTeamsPerCU = 2;
+
 int print_kernel_trace;
 
 // Size of the target call stack struture
@@ -350,6 +355,8 @@
       llvm::omp::AMDGPUGpuGridValues[llvm::omp::GVIDX::GV_Max_WG_Size];
   static const int Default_WG_Size =
       llvm::omp::AMDGPUGpuGridValues[llvm::omp::GVIDX::GV_Default_WG_Size];
+  static const int Total_VGPR_Count = llvm::omp::AMDGPUGpuGridValues
+      [llvm::omp::GVIDX::GV_Total_Vector_Registers];
 
   using MemcpyFunc = atmi_status_t (*)(hsa_signal_t, void *, const void *,
                                        size_t size, hsa_agent_t);
@@ -790,7 +797,7 @@
            DeviceInfo.EnvNumTeams);
   } else {
     char *TeamsPerCUEnvStr = getenv("OMP_TARGET_TEAMS_PER_PROC");
-    int TeamsPerCU = 1; // default number of teams per CU is 1
+    int TeamsPerCU = DefaultTeamsPerCU;
     if (TeamsPerCUEnvStr) {
       TeamsPerCU = std::stoi(TeamsPerCUEnvStr);
     }
@@ -813,7 +820,7 @@
                    RTLDeviceInfoTy::Default_WG_Size);
   if (DeviceInfo.NumThreads[device_id] >
       DeviceInfo.ThreadsPerGroup[device_id]) {
-    DeviceInfo.NumTeams[device_id] = DeviceInfo.ThreadsPerGroup[device_id];
+    DeviceInfo.NumThreads[device_id] = DeviceInfo.ThreadsPerGroup[device_id];
     DP("Default number of threads exceeds device limit, capping at %d\n",
        DeviceInfo.ThreadsPerGroup[device_id]);
   }
@@ -1777,7 +1784,27 @@
   */
 
   int num_groups = 0;
-  int threadsPerGroup = RTLDeviceInfoTy::Default_WG_Size;
+  // Compute the maximum number of VGPRs allowed for a workgroup
+  int max_vgprs_per_group = RTLDeviceInfoTy::Total_VGPR_Count / MinTeamsPerCU;
+
+  // Compute the max number of threads per group based on the kernel VGPR
+  // usage. Guard against a kernel that reports zero VGPRs (e.g. missing
+  // metadata) to avoid a division by zero; fall back to the default size.
+  int threadsPerGroup = RTLDeviceInfoTy::Default_WG_Size;
+  if (vgpr_count > 0)
+    threadsPerGroup = max_vgprs_per_group / vgpr_count;
+
+  if (threadsPerGroup > RTLDeviceInfoTy::Default_WG_Size) {
+    // Cap it at the default workgroup size
+    threadsPerGroup = RTLDeviceInfoTy::Default_WG_Size;
+  } else if (threadsPerGroup < RTLDeviceInfoTy::Warp_Size) {
+    // Lower bound is a wavefront size
+    threadsPerGroup = RTLDeviceInfoTy::Warp_Size;
+  } else {
+    // Round it down to a multiple of wavefront size
+    threadsPerGroup = (threadsPerGroup / RTLDeviceInfoTy::Warp_Size) *
+                      RTLDeviceInfoTy::Warp_Size;
+  }
 
   getLaunchVals(threadsPerGroup, num_groups, KernelInfo->ConstWGSize,
                 KernelInfo->ExecutionMode, DeviceInfo.EnvTeamLimit,