Index: llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h
+++ llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h
@@ -91,7 +91,18 @@
   /// be converted to integer, violate subtarget's specifications, or are not
   /// compatible with minimum/maximum number of waves limited by flat work group
   /// size, register usage, and/or lds usage.
-  std::pair<unsigned, unsigned> getWavesPerEU(const Function &F) const;
+  std::pair<unsigned, unsigned> getWavesPerEU(const Function &F) const {
+    // Default/requested minimum/maximum flat work group sizes.
+    std::pair<unsigned, unsigned> FlatWorkGroupSizes = getFlatWorkGroupSizes(F);
+    return getWavesPerEU(F, FlatWorkGroupSizes);
+  }
+
+  /// Overload which uses the specified values for the flat work group sizes,
+  /// rather than querying the function itself. \p FlatWorkGroupSizes should
+  /// correspond to the function's value for getFlatWorkGroupSizes.
+  std::pair<unsigned, unsigned>
+  getWavesPerEU(const Function &F,
+                std::pair<unsigned, unsigned> FlatWorkGroupSizes) const;
 
   /// Return the amount of LDS that can be used that will not restrict the
   /// occupancy lower than WaveCount.
Index: llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPUSubtarget.cpp
@@ -533,13 +533,10 @@
 }
 
 std::pair<unsigned, unsigned> AMDGPUSubtarget::getWavesPerEU(
-  const Function &F) const {
+    const Function &F, std::pair<unsigned, unsigned> FlatWorkGroupSizes) const {
   // Default minimum/maximum number of waves per execution unit.
   std::pair<unsigned, unsigned> Default(1, getMaxWavesPerEU());
 
-  // Default/requested minimum/maximum flat work group sizes.
-  std::pair<unsigned, unsigned> FlatWorkGroupSizes = getFlatWorkGroupSizes(F);
-
   // If minimum/maximum flat work group sizes were explicitly requested using
   // "amdgpu-flat-work-group-size" attribute, then set default minimum/maximum
   // number of waves per execution unit to values implied by requested